# Installing packages

In [None]:
!pip install lottery-ticket-pruner
!pip install matplotlib
!pip install --user torch
!pip install pandas
!pip install -q transformers
!pip install --user tqdm
!pip install gdown

***Warning***: Depending on the runtime used, you might have to restart the kernel in order for the new libraries to be located properly.

# Import libraries

In [None]:
import tensorflow as tf
from tensorflow import keras
import pandas as pd
import numpy as np
from tqdm import tqdm
from lottery_ticket_pruner import LotteryTicketPruner
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from transformers import AutoTokenizer
import torch
from transformers import BertTokenizer, TFBertForSequenceClassification,AutoModel, AutoTokenizer
from tqdm import tqdm
from pathlib import Path
import pickle
import gdown

Download the tokenized tweets from Google Drive. If the automatic download fails, open each link in a browser, download the file, and upload it manually to the notebook's working directory.

In [None]:
# Google Drive download links of the pickled, pre-tokenized inputs, keyed by local file name.
files={
    "attention_masks.pkl": "https://drive.google.com/uc?id=1Th1ry1vnFkRsvjNU2s-bzLeeluHBIY7o",
    "sentiments_encoded.pkl": "https://drive.google.com/uc?id=1pblp6oH4Uz125bZx5jcMK-Wq0Bf3TO0J",
    "token_types.pkl": "https://drive.google.com/uc?id=1bKn1U9h83kLQQGwiKTDAKVLB_csqAG_J",
    "tweets_encoded.pkl": "https://drive.google.com/uc?id=1ZjHOZhml728fz_xa-3AGg9wysTNqjC1K",
}
# Files that already exist (from an earlier run, or uploaded manually after a failed
# download) are kept as they are, so re-running this cell neither re-fetches nor
# overwrites them.
for fname, url in files.items():
    if Path(fname).exists():
        print(fname + " already present, skipping download")
        continue
    gdown.download(url, fname)

# Load the tokenized tweets

In [None]:
def load_from_pickle():
    """Read the four pickled inputs from the current working directory.

    Each file is opened, unpickled and acknowledged with a progress message,
    in the order tweets, attention masks, token types, sentiments.

    Returns:
        The tuple (tweets_encoded, attention_masks, token_types, sentiments_encoded).
    """
    # NOTE(review): pickle.load can execute arbitrary code embedded in the file;
    # only run this on pickles obtained from a trusted source.
    sources = (
        ('tweets_encoded.pkl', "Loaded tweets_encoded"),
        ('attention_masks.pkl', "Loaded attention_masks"),
        ('token_types.pkl', "Loaded token_types"),
        ('sentiments_encoded.pkl', "Loaded sentiments_encoded"),
    )

    loaded = []
    for path, message in sources:
        with open(path, 'rb') as handle:
            loaded.append(pickle.load(handle))
            print(message)

    return tuple(loaded)

In [None]:
# Unpickle the four inputs; the .pkl files must already be in the working directory.
tweets_encoded, attention_masks, token_types, sentiments_encoded = load_from_pickle()

In [None]:
# Bundle the three model inputs in the order in which make_train_data indexes them:
# [0] encoded tweets, [1] attention masks, [2] token types.
tweets = [tweets_encoded, attention_masks, token_types]
# Labels, kept under the name the experiment code reads.
sentiments = sentiments_encoded

# Define max length of embedding and the batch size

In [None]:
# Number of GPUs the batch size is scaled by.
no_gpu = 8
# Maximum length of an embedding.
max_length = 512
# Batch size used when batching the datasets: 40 samples for each of the `no_gpu` GPUs.
batch_size = no_gpu * 40

# Helper function


In [None]:
def map_tweet_to_dict(input_ids, attention_masks, token_type_ids, label):
  """Pair the three inputs of one example, keyed by name, with its label.

  Returns:
      A 2-tuple (features, label) where features maps "input_ids",
      "token_type_ids" and "attention_mask" to the corresponding arguments.
  """
  features = {}
  features["input_ids"] = input_ids
  features["token_type_ids"] = token_type_ids
  features["attention_mask"] = attention_masks
  return features, label

# Make train, validation, and test data
Due to time constraints only about 20% of the tweets are used: roughly 10% for training, 5% for validation, and 5% for testing.

In [None]:
def _select_rows(tweets, sentiments, keep):
  """Return the rows of the three input columns and of the labels whose flag in `keep` is true."""
  picked = [i for i, flag in enumerate(keep) if flag]
  columns = [[column[i] for i in picked] for column in tweets[:3]]
  labels = [sentiments[i] for i in picked]
  return columns, labels


def _split_by_class(tweets, sentiments, ratio):
  """Randomly send about `ratio` of each class to a first part and the remainder to a second part.

  The first `sentiments.count(1)` rows are treated as the positive class and the
  following `sentiments.count(0)` rows as the negative class.
  NOTE(review): this presumes the labels are ordered positives first - confirm
  against the pickled data; rows beyond those two counts are dropped.
  """
  # Counting the classes directly avoids the `lst[-0:]` pitfall, which returns the
  # whole list (instead of nothing) when one class is absent.
  n_pos = sentiments.count(1)
  n_neg = sentiments.count(0)
  # Draw the positive mask before the negative one to keep the random stream stable.
  pos_mask = np.random.rand(n_pos) < ratio
  neg_mask = np.random.rand(n_neg) < ratio
  chosen = np.concatenate([pos_mask, neg_mask])
  return _select_rows(tweets, sentiments, chosen), _select_rows(tweets, sentiments, ~chosen)


def make_train_data(start_tweets, start_sentiments):
  """Sub-sample the corpus and build batched train / validation / test datasets.

  With the NumPy seed fixed to 33, every row is kept with probability 0.2. The
  kept rows are split per class: about half go to training, and the remaining
  half is again split in two for validation and test.

  Args:
      start_tweets: sequence whose first three entries are the per-example
          columns passed to map_tweet_to_dict as input ids, attention masks
          and token type ids.
      start_sentiments: per-example labels (0 or 1), aligned with the columns.

  Returns:
      ((train_ds, train_labels), (validation_ds, validation_labels),
      (test_ds, test_labels)). Each dataset is batched with the module-level
      `batch_size`; train and validation are shuffled, test keeps its order so
      that predictions line up with `test_labels`.
  """
  print("I start with:"+str(len(start_sentiments)))
  np.random.seed(33)
  msk = np.random.rand(len(start_sentiments)) < 0.2
  tweets, sentiments = _select_rows(start_tweets, start_sentiments, msk)

  print("I continue with:"+str(len(sentiments)))
  ratio = 0.5
  (train_tweets, train_sentiments), (rest_tweets, rest_sentiments) = _split_by_class(tweets, sentiments, ratio)

  rest_ratio = 0.5
  (validation_tweets, validation_sentiments), (test_tweets, test_sentiments) = _split_by_class(rest_tweets, rest_sentiments, rest_ratio)

  print("I have: Train:"+str(len(train_sentiments)) + " Validation:" +str(len(validation_sentiments)) + " Test:" +str(len(test_sentiments)))

  train_tweets_ds = tf.data.Dataset.from_tensor_slices((train_tweets[0], train_tweets[1], train_tweets[2], train_sentiments)).map(map_tweet_to_dict).shuffle(len(train_sentiments)).batch(batch_size)
  print("Train loaded")
  validation_tweets_ds = tf.data.Dataset.from_tensor_slices((validation_tweets[0], validation_tweets[1], validation_tweets[2], validation_sentiments)).map(map_tweet_to_dict).shuffle(len(validation_sentiments)).batch(batch_size)
  print("Validation loaded")
  # The test set is deliberately not shuffled: evaluate() compares predictions
  # with test_sentiments position by position.
  test_tweets_ds = tf.data.Dataset.from_tensor_slices((test_tweets[0], test_tweets[1], test_tweets[2], test_sentiments)).map(map_tweet_to_dict).batch(batch_size)
  print("Test loaded")
  return (train_tweets_ds, train_sentiments), (validation_tweets_ds, validation_sentiments), (test_tweets_ds, test_sentiments)

# Define the BERT model and training procedure #

In [None]:
from keras.callbacks import ModelCheckpoint, LearningRateScheduler
from transformers import TFBertForSequenceClassification
import tensorflow as tf


# Build the initial BERT model. Each one-shot pruning trial, and the iterative pruning
# run, start again from these weights.
def model_builder():
  """Load the 'bert-base-uncased' sequence classifier configured for 2 labels."""
  return TFBertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels = 2)


# Re-create a (possibly pruned) model: build a fresh network and copy the given weights into it.
def return_model(model):
    """Return a newly built model carrying the same weights as `model`."""
    clone = model_builder()
    weights = model.get_weights()
    clone.set_weights(weights)
    return clone


# Fine-tune a copy of the given model, with early stopping to prevent overfitting.
# The epoch that reached the best validation accuracy is returned so that the pruned
# networks are not trained for longer than the baseline - as required in LTH.
def train_model(modelX, num_epochs, X_train, Y_train, X_eval, Y_eval):
    """Train a copy of `modelX` and report the best epoch.

    Args:
        modelX: model whose weights are copied into the network that is trained;
            `modelX` itself is not fitted.
        num_epochs: maximum number of epochs to train for.
        X_train: training data passed to `fit`; in this notebook a batched
            tf.data.Dataset that already contains the labels.
        Y_train: unused here (the labels travel inside `X_train`).
        X_eval: validation data passed to `fit` as `validation_data`.
        Y_eval: unused here (the labels travel inside `X_eval`).

    Returns:
        (model, best_epoch): the trained copy, with the best weights restored by
        early stopping, and the 1-based epoch with the highest validation accuracy.
    """
    stop_early = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=3, restore_best_weights=True)

    # Mirror across GPUs only when several are visible. Otherwise fall back to the
    # default strategy, so that the model is also built and compiled on CPU or
    # single-GPU runtimes (previously `model` was undefined in that case).
    gpus = tf.config.list_logical_devices('GPU')
    if len(gpus) > 1:
        strategy = tf.distribute.MirroredStrategy(gpus)
    else:
        strategy = tf.distribute.get_strategy()

    with strategy.scope():
        model = return_model(modelX)

        optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5)
        loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')

        model.compile(optimizer=optimizer, loss=loss, metrics=[metric])

    # No batch_size here: the datasets are already batched, and Keras documents that
    # batch_size must not be specified for dataset inputs.
    history = model.fit(X_train, epochs=num_epochs, validation_data=X_eval, callbacks=[stop_early])

    best_epoch = np.argmax(history.history['val_accuracy']) + 1

    return (model, best_epoch)


__For obtaining the class prediction__

In [None]:
# Turn the model output into one predicted class per tweet.
def get_decision(decisions):
  """Return, for every row of `decisions[0]`, the index of its largest score.

  Args:
      decisions: indexable whose first element holds one row of per-class
          scores per tweet.

  Returns:
      A list with one class index per row, in row order.
  """
  scores = decisions[0]
  return [np.argmax(scores[row]) for row in range(len(scores))]

# Evaluate the results on the test set #

In [None]:
# Evaluate a (pruned) model on the test set: compute the accuracy, TPR and TNR
def evaluate(net,X,Y):
    """Predict on `X` and score the predicted classes against the labels `Y`.

    Args:
        net: model exposing `predict`; its output is decoded with get_decision.
        X: inputs to predict on. The predictions must come back in the same
            order as `Y` (the test dataset in this notebook is not shuffled).
        Y: true labels, 1 for the positive class and 0 for the negative class.

    Returns:
        (acc, tpr, tnr): accuracy, true-positive rate and true-negative rate.
        A rate whose denominator is zero (no samples, or a class that does not
        occur in `Y`) is reported as NaN.
    """
    # NOTE(review): X is an already-batched dataset in this notebook, for which
    # batch_size is presumably ignored - confirm before relying on the value 10.
    Z = net.predict(X,batch_size=10)
    Zbin = np.array(get_decision(Z))
    Y = np.array(Y)

    def rate(hits, total):
        # A rate over an empty set is undefined; report NaN rather than dividing by zero.
        return hits / total if total else float('nan')

    n = len(Zbin)
    n0 = np.sum(Y == 0)
    n1 = np.sum(Y == 1)

    acc = rate(np.sum(Zbin == Y), n)
    # Count the positives predicted as class 1 (instead of summing the raw predicted
    # labels), mirroring how the TNR counts the negatives predicted as class 0.
    tpr = rate(np.sum(Zbin[Y == 1] == 1), n1)
    tnr = rate(np.sum(Zbin[Y == 0] == 0), n0)

    print("Acc: " + str(acc)+ str("\t") +
          "Tpr:" + str(tpr) + str("\t") +
          "Tnr:" + str(tnr) + str("\t"))

    return (acc, tpr, tnr)

# Create a plot with values for the baseline, pruned with one-shot, and pruned with iterative model #
The values can be Accuracy, TPR, or TNR

In [None]:
def create_figure(title, x_label, y_label, x0, y0, xy0err, x1, y1, xy1err, x2, y2, xy2err):
    """Plot three curves with dashed error bars on one titled, labelled figure.

    Args:
        title: figure title.
        x_label: label of the x axis.
        y_label: label of the y axis.
        x0, y0, xy0err: x values, y values and y errors of the trained baseline
            model (black).
        x1, y1, xy1err: same for the one-shot pruned models (red).
        x2, y2, xy2err: same for the iteratively pruned models (blue).
    """
    fig, ax = plt.subplots(figsize=(10,10))

    curves = (
        (x0, y0, xy0err, 'k', "Baseline"),
        (x1, y1, xy1err, 'r', "One-shot pruning"),
        (x2, y2, xy2err, 'b', "Iterative pruning"),
    )
    for xs, ys, yerr, colour, label in curves:
        err = ax.errorbar(xs, ys, color=colour, yerr=yerr, ecolor=colour, elinewidth=1, capsize=3, label=label)
        # The last member of the container holds the error-bar lines; draw them dashed.
        err[-1][0].set_linestyle('--')

    # The title was previously accepted but never drawn, and the curves were only
    # distinguishable by colour.
    ax.set_title(title)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.legend()


# Prune a model using the one-shot or iterative pruning method for multiple values of *p%* and train it #

In [None]:
def apply_pruning_method(pruning_method, num_repetition, X_train, Y_train, X_eval, Y_eval, X_test, Y_test, num_epochs, num_pruning_trials, initial_baseline_model, initial_baseline_model_weights, trained_baseline_model, trained_baseline_model_weights, prune_strategy, prune_percentage_for_iterative):
    """Run `num_pruning_trials` prune / retrain / evaluate rounds with one pruning method.

    Every round restores the saved initial and trained baseline weights, derives a
    mask from the trained baseline model, applies it to the initial model, trains
    that model with train_model and evaluates it on the test data.

    Args:
        pruning_method: "one_shot" resets the pruner masks before every round and
            uses that round's value from a linspace grid starting at 0.3 as the
            prune percentage; "iterative" keeps the masks of the previous rounds
            and always passes `prune_percentage_for_iterative`.
        num_repetition: index of the experiment repetition, only used in log output.
        X_train, Y_train, X_eval, Y_eval: forwarded to train_model.
        X_test, Y_test: forwarded to evaluate.
        num_epochs: maximum number of epochs for training each pruned model.
        num_pruning_trials: number of rounds.
        initial_baseline_model, initial_baseline_model_weights: untrained model
            and the weights it is reset to at the start of every round.
        trained_baseline_model, trained_baseline_model_weights: trained model
            and the weights it is reset to at the start of every round.
        prune_strategy: strategy name handed to the pruner.
        prune_percentage_for_iterative: per-round percentage of the iterative method.

    Returns:
        (pruned_accs, pruned_tprs, pruned_tnrs): one value per round each.
    """
    pruner = LotteryTicketPruner(initial_baseline_model)

    # One entry per pruning round
    pruned_accs, pruned_tprs, pruned_tnrs = [], [], []

    trial_percentages = list(np.linspace(start=0.3,stop=1,num=num_pruning_trials+1))[0:-1]
    for pruning_trial_no, prune_percentage in enumerate(trial_percentages, start=1):

        # Pick the mask to start from: a new one for one-shot, the previous one for iterative
        if pruning_method=="one_shot":
            pruner.reset_masks()
        elif pruning_method=="iterative":
            prune_percentage = prune_percentage_for_iterative

        # Put both baseline models back to their saved weights
        initial_baseline_model.set_weights(initial_baseline_model_weights)
        trained_baseline_model.set_weights(trained_baseline_model_weights)

        # Derive the mask from the trained baseline model, then prune the initial model with it
        pruner.calc_prune_mask(trained_baseline_model, prune_percentage, prune_strategy)
        pruner.apply_pruning(initial_baseline_model)

        # Train the pruned initial model for at most `num_epochs` epochs
        pruned_trained_model, stoped_epoch = train_model(initial_baseline_model, num_epochs, X_train, Y_train, X_eval, Y_eval)

        print(pruning_method + " Experiment repetition no:"+ str(num_repetition)+ " Pruned model at " + str(prune_percentage) + "pruning trial:"+ str(pruning_trial_no)+"\n")

        acc, tpr, tnr = evaluate(pruned_trained_model, X_test, Y_test)
        pruned_accs.append(acc)
        pruned_tprs.append(tpr)
        pruned_tnrs.append(tnr)

    print(pruning_method + " After pruning at all percentages:")
    print("Accs:"+ str(pruned_accs))
    print("TPRS:"+ str(pruned_tprs))
    print("TNRS:"+ str(pruned_tnrs))
    return (pruned_accs, pruned_tprs, pruned_tnrs)

# Repeat:  Prune a model using the one-shot or iterative pruning method for multiple values of *p%*, train it, and evaluate the results #

In [None]:
def repeat_experiment(num_epochs=20, num_experiment_repetitions=5, num_pruning_trials= 9, prune_strategy="smallest_weights_global", prune_percentage_for_iterative=0.2):
    """Repeat the baseline / one-shot / iterative experiment and collect all metrics.

    Each repetition rebuilds the datasets from the module-level `tweets` and
    `sentiments`, trains and evaluates a baseline model, then runs one-shot and
    iterative pruning. The pruned models are trained for at most the epoch at
    which the baseline reached its best validation accuracy.

    Args:
        num_epochs: maximum number of epochs for training the baseline model.
        num_experiment_repetitions: number of repetitions.
        num_pruning_trials: pruning rounds per method and repetition.
        prune_strategy: strategy name forwarded to apply_pruning_method.
        prune_percentage_for_iterative: per-round percentage for the iterative method.

    Returns:
        A 9-tuple: baseline accuracies, TPRs and TNRs (one value per repetition),
        then the one-shot accuracies, TPRs and TNRs, then the iterative
        accuracies, TPRs and TNRs (each a list with one entry per pruning round,
        holding one value per repetition).
    """
    def per_trial_store():
        return [[] for _ in range(num_pruning_trials)]

    baseline_accs, baseline_tprs, baseline_tnrs = [], [], []
    one_shot_accs, one_shot_tprs, one_shot_tnrs = per_trial_store(), per_trial_store(), per_trial_store()
    iterative_accs, iterative_tprs, iterative_tnrs = per_trial_store(), per_trial_store(), per_trial_store()

    for repetition in range(num_experiment_repetitions):
        print("****************")
        print("Experiment repetition:" + str(repetition))
        print("****************")

        # Build the train / validation / test datasets
        (X_train, Y_train), (X_eval, Y_eval), (X_test, Y_test) = make_train_data(tweets, sentiments)
        print("Dataset computed")

        # Remember the untrained weights: every pruning round starts again from them
        initial_baseline_model = model_builder()
        initial_baseline_model_weights = initial_baseline_model.get_weights()

        # Train the baseline and remember its trained weights
        trained_baseline_model, trnd_bm_stop_epoch = train_model(initial_baseline_model, num_epochs, X_train, Y_train, X_eval, Y_eval)
        trained_baseline_model_weights = trained_baseline_model.get_weights()

        print("-------- Baseline model results at repetition "+str(repetition)+":\n")
        b_acc, b_tpr, b_tnr = evaluate(trained_baseline_model, X_test, Y_test)
        baseline_accs.append(b_acc)
        baseline_tprs.append(b_tpr)
        baseline_tnrs.append(b_tnr)

        # Arguments shared by both pruning methods; the epoch budget is the baseline's best epoch
        shared_args = (
            X_train, Y_train, X_eval, Y_eval, X_test, Y_test,
            trnd_bm_stop_epoch, num_pruning_trials,
            initial_baseline_model, initial_baseline_model_weights,
            trained_baseline_model, trained_baseline_model_weights,
            prune_strategy, prune_percentage_for_iterative,
        )

        print("-------- Starting One-Shot with initialization from original initial network:\n")
        os_accs, os_tprs, os_tnrs = apply_pruning_method("one_shot", repetition, *shared_args)

        print("-------- Starting Iterative with initialization from original initial network:\n")
        it_accs, it_tprs, it_tnrs = apply_pruning_method("iterative", repetition, *shared_args)

        # File this repetition's results under their pruning round
        collected = (
            (one_shot_accs, os_accs), (one_shot_tprs, os_tprs), (one_shot_tnrs, os_tnrs),
            (iterative_accs, it_accs), (iterative_tprs, it_tprs), (iterative_tnrs, it_tnrs),
        )
        for trial in range(num_pruning_trials):
            for store, results in collected:
                store[trial].append(results[trial])

    return (
        baseline_accs, baseline_tprs, baseline_tnrs,
        one_shot_accs, one_shot_tprs, one_shot_tnrs,
        iterative_accs, iterative_tprs, iterative_tnrs,
    )
            
            

Compute pruning percentage (*p%*) values for iterative pruning

In [None]:
def compute_pruning_percentages_iterative(num_pruning_trials, pruning_percentage):
  """Return the cumulative pruned fraction after each iterative pruning round.

  Every round removes `pruning_percentage` of the weights that are still left,
  so the fraction left shrinks geometrically.

  Args:
      num_pruning_trials: number of rounds.
      pruning_percentage: fraction of the remaining weights removed per round.

  Returns:
      A list with one cumulative pruned fraction per round.
  """
  cumulative = []
  kept = 1
  for _ in range(num_pruning_trials):
    kept = kept - kept * pruning_percentage
    cumulative.append(1 - kept)
  return cumulative

# Create the Accuracy, TPR and TNR plots #

In [None]:
def create_plots_for_executed_experiments(num_experiment_repetitions, num_pruning_trials, prune_percentage_for_iterative, b_accs, b_tprs, b_tnrs, os_init_accs, os_init_tprs, os_init_tnrs, it_init_accs, it_init_tprs, it_init_tnrs):
    """Draw the accuracy, TPR and TNR figures for the executed experiments.

    For every metric the plotted value is the mean over the repetitions and the
    error bars reach down to the minimum and up to the maximum. The baseline is
    drawn as a constant line over the one-shot pruning percentages.

    Args:
        num_experiment_repetitions: not used by this function.
        num_pruning_trials: number of pruning rounds per method.
        prune_percentage_for_iterative: per-round percentage of the iterative
            method, used to compute its x positions.
        b_accs, b_tprs, b_tnrs: baseline metrics, one value per repetition.
        os_init_accs, os_init_tprs, os_init_tnrs: one-shot metrics, indexed by
            pruning round, each entry holding the values of all repetitions.
        it_init_accs, it_init_tprs, it_init_tnrs: same for the iterative method.
    """
    def spread(values):
        # Mean plus the distances from the mean down to the min and up to the max
        centre = np.mean(values)
        return centre, centre - np.min(values), np.max(values) - centre

    def constant_series(values):
        # Baseline: the same point and error bar repeated at every pruning percentage
        centre, below, above = spread(values)
        return [centre] * num_pruning_trials, [[below] * num_pruning_trials, [above] * num_pruning_trials]

    def per_trial_series(values_by_trial):
        # Pruned models: one point and error bar per pruning round
        centres, below, above = [], [], []
        for trial in range(num_pruning_trials):
            centre, low, high = spread(values_by_trial[trial])
            centres.append(centre)
            below.append(low)
            above.append(high)
        return centres, [below, above]

    # x positions of the two pruning methods
    pruning_percentages_one_shot = list(np.linspace(start=0.3,stop=1,num=num_pruning_trials+1))[0:-1]
    pruning_percentages_iterative = compute_pruning_percentages_iterative(num_pruning_trials, prune_percentage_for_iterative)

    baselines = [constant_series(values) for values in (b_accs, b_tprs, b_tnrs)]
    one_shots = [per_trial_series(values) for values in (os_init_accs, os_init_tprs, os_init_tnrs)]
    iteratives = [per_trial_series(values) for values in (it_init_accs, it_init_tprs, it_init_tnrs)]

    # All statistics are ready; now draw one figure per metric
    headings = (("Accuracies", "Accuracy"), ("TPRS", "TPR"), ("TNRS", "TNR"))
    for (title, y_label), (y0, y0err), (y1, y1err), (y2, y2err) in zip(headings, baselines, one_shots, iteratives):
        create_figure(title, "Pruned ratio", y_label, pruning_percentages_one_shot, y0, y0err, pruning_percentages_one_shot, y1, y1err, pruning_percentages_iterative, y2, y2err)
    

# Run the experiment #

In [None]:
# Maximum number of epochs for training the baseline model
num_epochs = 10
# How many times the whole experiment is repeated
num_experiment_repetitions = 3
# Pruning rounds per method within one repetition
num_pruning_trials = 7

# Strategy name handed to the pruner
prune_strategy = "smallest_weights_global"
# Percentage passed to the pruner in every round of iterative pruning
prune_percentage_for_iterative = 0.3

In [None]:
# Run every repetition of the experiment and keep the nine metric collections:
# baseline (b_*), one-shot pruning (os_init_*) and iterative pruning (it_init_*),
# each as accuracies, TPRs and TNRs.
(   b_accs, b_tprs, b_tnrs,
    os_init_accs, os_init_tprs, os_init_tnrs,
    it_init_accs, it_init_tprs, it_init_tnrs) = repeat_experiment(num_epochs, num_experiment_repetitions, num_pruning_trials,prune_strategy, prune_percentage_for_iterative);

# Create the plots #

In [None]:
# Draw the accuracy, TPR and TNR figures from the metrics collected by repeat_experiment.
create_plots_for_executed_experiments(num_experiment_repetitions, num_pruning_trials, prune_percentage_for_iterative, b_accs, b_tprs, b_tnrs, os_init_accs, os_init_tprs, os_init_tnrs, it_init_accs, it_init_tprs, it_init_tnrs)