# Setup: install dependencies, imports, and helper functions

In [None]:
!pip install sentence_transformers
!pip install setfit

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting sentence_transformers
  Downloading sentence-transformers-2.2.2.tar.gz (85 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.0/86.0 KB[0m [31m812.5 kB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting transformers<5.0.0,>=4.6.0
  Downloading transformers-4.26.1-py3-none-any.whl (6.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.3/6.3 MB[0m [31m60.0 MB/s[0m eta [36m0:00:00[0m
Collecting sentencepiece
  Downloading sentencepiece-0.1.97-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m30.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting huggingface-hub>=0.4.0
  Downloading huggingface_hub-0.12.1-py3-none-any.whl (190 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m190.

In [None]:
from sklearn.metrics import accuracy_score, f1_score
from sklearn.linear_model import LogisticRegression
from sentence_transformers import SentenceTransformer, InputExample, losses, models, datasets, evaluation
from torch.utils.data import DataLoader
from datasets import load_dataset
from sklearn.manifold import TSNE
from matplotlib import pyplot as plt

# import warnings filter
from warnings import simplefilter
# ignore all future warnings
simplefilter(action='ignore', category=FutureWarning)

import pandas as pd
import numpy as np

import torch
import random
import torch

def set_seed(seed):
  """Seed the global RNGs of Python's `random`, NumPy and PyTorch.

  All three are seeded with the same value, in that order, so that sampling,
  pair generation and DataLoader shuffling are reproducible.
  """
  for seeder in (random.seed, np.random.seed, torch.manual_seed):
    seeder(seed)

In [None]:
def sentence_pairs_generation(sentences, labels, pairs):
  """Append one positive and one negative sentence pair per input sentence.

  For every sentence, a partner with the same label is drawn at random and
  stored with target 1.0. A partner with a different label is drawn at
  random and stored with target 0.0. Draws use NumPy's global RNG, so results
  are reproducible after `set_seed`.

  Args:
    sentences: indexable sequence of texts (`run` passes an `np.ndarray`).
    labels: class labels aligned with `sentences`. They are converted with
      `np.asarray` so the element-wise comparisons also work for plain lists.
    pairs: list that is extended in place with `InputExample` objects.

  Returns:
    The same `pairs` list (callers reassign it across iterations).

  Raises:
    ValueError: if `labels` contain fewer than two classes, because no
      negative partner can then be drawn.
  """
  labels = np.asarray(labels)
  classes = np.unique(labels)
  # Fail up front with a clear message instead of an opaque
  # `np.random.choice` error on an empty candidate array.
  if len(sentences) > 0 and len(classes) < 2:
    raise ValueError(
        "sentence_pairs_generation needs at least two classes to build negative pairs")

  # Candidate indices are computed once per class rather than rescanning
  # `labels` for every sentence. The arrays are identical to the per-sentence
  # versions, so the RNG draws (and seeded results) are unchanged.
  same_class = {c: np.where(labels == c)[0] for c in classes}
  other_class = {c: np.where(labels != c)[0] for c in classes}

  for idxA in range(len(sentences)):
    current_sentence = sentences[idxA]
    label = labels[idxA]

    # Positive pair. Note that the draw may return idxA itself.
    idxB = np.random.choice(same_class[label])
    pairs.append(InputExample(texts=[current_sentence, sentences[idxB]], label=1.0))

    # Negative pair: any sentence whose label differs from `label`.
    neg_sentence = sentences[np.random.choice(other_class[label])]
    pairs.append(InputExample(texts=[current_sentence, neg_sentence], label=0.0))

  return pairs

In [None]:
# Download the three SetFit benchmark datasets from the Hugging Face Hub.
# Each result is indexed below by its "train" and "test" splits; conversion to
# pandas happens later, in load().
dataset1, dataset2, dataset3 = (
    load_dataset(hub_name)
    for hub_name in (
        "SetFit/toxic_conversations_50k",
        "SetFit/tweet_eval_stance_abortion",
        "SetFit/catalonia_independence_es",
    )
)





  0%|          | 0/2 [00:00<?, ?it/s]



  0%|          | 0/2 [00:00<?, ?it/s]



  0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
def load(dataset1, n_train=100, seed=42):
  """Convert a train/test dataset into the frames and lists consumed by `run`.

  Args:
    dataset1: mapping with "train" and "test" splits whose rows carry "text"
      and "label" fields (e.g. the result of `load_dataset`).
    n_train: size of the random training pool taken from the shuffled train
      split. This is the total number of rows, not a per-class count, and it
      is capped at the split size. Per-class sampling happens later in `run`.
    seed: seed for shuffling the train split.

  Returns:
    (train_df, eval_df, text_col, category_col, x_eval, y_eval)
    - `train_df` and `eval_df` have exactly two columns, 0 (text) and 1 (label).
    - `text_col` and `category_col` are those column labels.
    - `x_eval` and `y_eval` are plain lists taken from the full test split.
  """
  def _to_frame(rows):
    # `run` indexes columns by the integer labels 0/1, so rename and keep
    # only those two columns (extra fields such as label text are dropped).
    return pd.DataFrame(data=rows).rename(columns={'text': 0, 'label': 1})[[0, 1]]

  train_split = dataset1["train"]
  # Capping the pool size avoids an index error on splits smaller than n_train.
  pool_size = min(n_train, len(train_split))
  train_df = _to_frame(train_split.shuffle(seed=seed).select(range(pool_size)))
  eval_df = _to_frame(dataset1["test"])

  text_col = train_df.columns.values[0]
  category_col = train_df.columns.values[1]

  x_eval = eval_df[text_col].values.tolist()
  y_eval = eval_df[category_col].values.tolist()

  return train_df, eval_df, text_col, category_col, x_eval, y_eval

# SetFit: fine-tuned encoder vs. frozen-encoder baseline

In [None]:
def _sample_per_class(train_df, category_col, num_training):
  """Draw up to `num_training` rows from each class in `train_df`, in sorted label order."""
  parts = []
  for label in sorted(train_df[category_col].unique()):
    class_rows = train_df[train_df[category_col] == label]
    # A class with fewer rows than requested contributes all of its rows
    # instead of raising inside DataFrame.sample.
    parts.append(class_rows.sample(min(num_training, len(class_rows))))
  return pd.concat(parts)


def _linear_probe_accuracy(encoder, x_train, y_train, x_eval, y_eval):
  """Fit a logistic-regression head on `encoder` embeddings and return eval accuracy."""
  X_train = encoder.encode(x_train)
  X_eval = encoder.encode(x_eval)
  clf = LogisticRegression()
  clf.fit(X_train, y_train)
  return accuracy_score(y_eval, clf.predict(X_eval))


def run(train_df,eval_df,text_col,category_col,x_eval,y_eval):
  """Compare a frozen Sentence-Transformer baseline with SetFit fine-tuning.

  Steps:
  1. Seed the RNGs and sample `num_training` examples from every class of
     `train_df`. The original version sampled only classes 0 and 1, so a
     third class was never trained on.
  2. Build `num_itr` rounds of contrastive pairs.
  3. Fine-tune a copy of `st_model` for one epoch with CosineSimilarityLoss.
  4. Fit a logistic-regression head on embeddings from the original and the
     fine-tuned encoder, and print the eval accuracy of each.

  Args:
    train_df: frame whose `text_col` holds texts and `category_col` labels.
    eval_df: unused; kept for interface compatibility (evaluation uses
      `x_eval` / `y_eval`).
    text_col, category_col: column labels in `train_df`.
    x_eval, y_eval: evaluation texts and labels.

  Returns:
    dict with keys 'no_fit' and 'setfit' holding the two accuracies.
  """
  #@title SetFit
  st_model = 'paraphrase-mpnet-base-v2' #@param ['paraphrase-mpnet-base-v2', 'all-mpnet-base-v1', 'all-mpnet-base-v2', 'stsb-mpnet-base-v2', 'all-MiniLM-L12-v2', 'paraphrase-albert-small-v2', 'all-roberta-large-v1']
  num_training = 8 #@param ["8", "16", "32", "64", "128", "256", "512"] {type:"raw"}
  num_itr = 5 #@param ["1", "2", "3", "4", "5", "10"] {type:"raw"}

  set_seed(0)
  # Equal samples per class. For binary 0/1 labels this makes the same RNG
  # draws as the original (class 0 sampled first, then class 1).
  train_df_sample = _sample_per_class(train_df, category_col, num_training)
  x_train = train_df_sample[text_col].values.tolist()
  y_train = train_df_sample[category_col].values.tolist()

  train_examples = []
  for _ in range(num_itr):
    train_examples = sentence_pairs_generation(np.array(x_train), np.array(y_train), train_examples)

  # Two independent copies: `orig_model` stays frozen as the baseline.
  orig_model = SentenceTransformer(st_model)
  model = SentenceTransformer(st_model)

  # S-BERT adaptation (contrastive fine-tuning)
  train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)
  train_loss = losses.CosineSimilarityLoss(model)
  model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1, warmup_steps=10, show_progress_bar=True)

  # No Fit: frozen encoder + linear head
  acc_no_fit = _linear_probe_accuracy(orig_model, x_train, y_train, x_eval, y_eval)
  print('Acc. No Fit', acc_no_fit)

  # With Fit (SetFit): fine-tuned encoder + linear head
  acc_setfit = _linear_probe_accuracy(model, x_train, y_train, x_eval, y_eval)
  print('Acc. SetFit', acc_setfit)

  return {'no_fit': acc_no_fit, 'setfit': acc_setfit}



In [None]:
# Run the SetFit-vs-baseline comparison on each dataset. Each block of output
# is labelled so its accuracies can be traced back to the dataset.
for dataset_name, dataset in [
    ("SetFit/toxic_conversations_50k", dataset1),
    ("SetFit/tweet_eval_stance_abortion", dataset2),
    ("SetFit/catalonia_independence_es", dataset3),
]:
  print(f"=== {dataset_name} ===")
  train_df, eval_df, text_col, category_col, x_eval, y_eval = load(dataset)
  run(train_df, eval_df, text_col, category_col, x_eval, y_eval)






                                                    0  1
0   Do you have any facts, data, examples, studies...  0
1   "The Catholic Church, in general, has been slo...  0
2   Trump will envy Putin even more if this lawsui...  0
3   The reason he is owed compensation is that Can...  0
4   It serves to show just how slow and burdensome...  0
..                                                ... ..
95  THE SUBWAY MAYORS - Ford and Tory\n \nFord and...  0
96  Let me guess, you're a white male, and althoug...  1
97                    Ugh...hell on earth experience.  0
98                     My deepest ALOHA to his ohana.  0
99  I thought she did a decent job. Her voice both...  0

[100 rows x 2 columns]


Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/10 [00:00<?, ?it/s]

Acc. No Fit 0.51404




Acc. SetFit 0.54336
                                                    0  1
0   @user @user @user Abortion Workers Charged wit...  1
1   We should not deny the basic human right to li...  1
2   @user because it's not your baby, body or deci...  2
3   #cogar @user get more #questions on #d #ballot...  0
4   @user -there should be a "stigma" to butcherin...  1
..                                                ... ..
95  @user No, I can't explain why you would consid...  2
96  @user When we #PrayTheRosary lest we forget to...  0
97  There is something very sinister, NAZI-esque, ...  1
98  @user "if u had been my wife I would have blow...  0
99  Americans clearly support family planning. Cut...  2

[100 rows x 2 columns]


Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/10 [00:00<?, ?it/s]

Acc. No Fit 0.5809248554913294




Acc. SetFit 0.5953757225433526
                                                    0  1
0   RT @NotMadeOfClay: Música y mirar las estrella...  1
1   Hola @sanchezcastejon Si no sabes para que te ...  1
2   Cs, ese partido que  ?pacta con Vox en Andaluc...  1
3   RT @Esp_Interativo: É isso o que a gente quer!...  2
4   🔴DIRECTO | Vicent Nos fue contactado por Puigd...  2
..                                                ... ..
95  @No_T_Calientes @rrincon84 @JesusCablegui2 @La...  1
96  A ojos de quién? De coscubiela, arrimadas y su...  1
97  El Ejército venezolano bloquea la entrada de a...  2
98  Dejando caer a este gobierno, el independentis...  2
99  @Elisabobergpast @FrayJosepho Todos los huérfa...  0

[100 rows x 2 columns]


Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/10 [00:00<?, ?it/s]

Acc. No Fit 0.4503968253968254
Acc. SetFit 0.45634920634920634
