In [1]:
!pip install sentence-transformers datasets



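## Pruning

Apply L1 (magnitude-based) unstructured pruning with `amount=0.2` to every parameter whose name contains `weight`, then re-score the STS-B validation split with the pruned model.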

In [2]:
from sentence_transformers import SentenceTransformer 
# Baseline sentence encoder; later cells prune and then quantize this same object.
distilroberta = SentenceTransformer(model_name_or_path='stsb-distilroberta-base-v2')

In [3]:
from datasets import load_metric, load_dataset 
# STS-B metric + dataset: the benchmark scored throughout this notebook.
stsb_metric, stsb = load_metric('glue', 'stsb'), load_dataset('glue', 'stsb')

# MRPC metric + dataset: loaded as well, though no cell shown here reads them.
mrpc_metric, mrpc = load_metric('glue', 'mrpc'), load_dataset('glue', 'mrpc')

Reusing dataset glue (/root/.cache/huggingface/datasets/glue/stsb/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
Reusing dataset glue (/root/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


In [4]:
import math 
import tensorflow as tf 

def roberta_sts_benchmark(batch, model=None):
    """Score sentence pairs by the angular similarity of their embeddings.

    Each side of every pair is encoded with ``model.encode``, L2-normalised
    along axis 1, and compared by cosine similarity. The cosine is clipped
    to [-1, 1] (keeping ``acos`` inside its domain despite float round-off)
    and mapped to ``1 - acos(cos) / pi``, so every score lies in [0, 1].

    Args:
        batch: Indexable by ``'sentence1'`` and ``'sentence2'`` (e.g. an
            STS-B split); each column is passed straight to ``model.encode``.
        model: Object exposing ``encode`` (e.g. a ``SentenceTransformer``).
            Defaults to the notebook-global ``distilroberta``, looked up at
            call time, so existing one-argument calls behave exactly as
            before. Pass a model explicitly to benchmark it without relying
            on whatever that global currently holds (it is mutated in place
            by pruning and rebound by quantization later in the notebook).

    Returns:
        A ``tf.Tensor`` with one score per sentence pair.
    """
    if model is None:
        model = distilroberta

    embeddings1 = tf.nn.l2_normalize(model.encode(batch['sentence1']), axis=1)
    embeddings2 = tf.nn.l2_normalize(model.encode(batch['sentence2']), axis=1)
    # Rows are unit vectors, so the row-wise dot product is the cosine similarity.
    cosine_similarities = tf.reduce_sum(tf.multiply(embeddings1, embeddings2), axis=1)
    clipped = tf.clip_by_value(cosine_similarities, -1.0, 1.0)
    return 1.0 - tf.acos(clipped) / math.pi

In [5]:
# Gold labels of the STS-B validation split, as a plain Python list.
references = list(stsb['validation']['label'])

In [6]:
# Baseline scores, computed before the pruning and quantization cells below modify the model.
distilroberta_results = roberta_sts_benchmark(batch=stsb['validation'])

In [7]:
from torch.nn.utils import prune 
# Magnitude (L1) unstructured pruning that zeroes 20% of the entries of each tensor it is given.
pruner = prune.L1Unstructured(0.2)

In [8]:
state_dict = distilroberta.state_dict()

# Overwrite matching entries in place so the very mapping returned by
# state_dict() is the one loaded back into the model in the next cell.
# NOTE(review): the plain substring test presumably also selects LayerNorm and
# embedding weights, not only Linear layers — confirm that was intended.
for key in list(state_dict):
    if "weight" not in key:
        continue  # entries without "weight" in their name (e.g. biases) stay untouched
    state_dict[key] = pruner.prune(state_dict[key])

In [9]:
# Copy the pruned tensors into the live model (strict=True is the default: every key must match).
distilroberta.load_state_dict(state_dict, strict=True)

<All keys matched successfully>

In [10]:
# Scores after the pruned weights have been loaded into distilroberta.
distilroberta_results_p = roberta_sts_benchmark(batch=stsb['validation'])

In [11]:
import pandas as pd 

# One column per model variant; rows are the STS-B metric's Pearson/Spearman correlations.
pd.DataFrame({
    column: stsb_metric.compute(predictions=predictions, references=references)
    for column, predictions in (
        ("DistillRoberta", distilroberta_results),
        ("DistillRobertaPruned", distilroberta_results_p),
    )
})

Unnamed: 0,DistillRoberta,DistillRobertaPruned
pearson,0.888461,0.849915
spearmanr,0.889246,0.849125


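## Quantization

Apply PyTorch dynamic int8 quantization to the `torch.nn.Linear` layers of the already-pruned model, then compare all three variants on the STS-B validation split.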

In [12]:
import torch 

# Dynamically quantize the already-pruned model: only torch.nn.Linear is listed
# in the qconfig spec, and those layers are converted to qint8.
distilroberta = torch.quantization.quantize_dynamic(
    distilroberta,
    {torch.nn.Linear: torch.quantization.default_dynamic_qconfig},
    dtype=torch.qint8,
)

In [13]:
# Scores from the pruned and then int8-quantized model.
distilroberta_results_pq = roberta_sts_benchmark(batch=stsb['validation'])

In [14]:
# Baseline vs. pruned vs. pruned+quantized: Pearson/Spearman correlation per variant.
pd.DataFrame({
    column: stsb_metric.compute(predictions=predictions, references=references)
    for column, predictions in (
        ("DistillRoberta", distilroberta_results),
        ("DistillRobertaPruned", distilroberta_results_p),
        ("DistillRobertaPrunedQINT8", distilroberta_results_pq),
    )
})

Unnamed: 0,DistillRoberta,DistillRobertaPruned,DistillRobertaPrunedQINT8
pearson,0.888461,0.849915,0.826784
spearmanr,0.889246,0.849125,0.824857
