In [23]:
import pandas as pd
import numpy as np
from setfit.modeling import SetFitBaseModel, SKLearnWrapper, sentence_pairs_generation
from sentence_transformers import SentenceTransformer, InputExample, losses
from sklearn.linear_model import LogisticRegression
from transformers import pipeline, AutoTokenizer, AutoModel, AlbertForMaskedLM, RobertaForMaskedLM, RobertaForSequenceClassification

from datasets import Dataset, load_dataset
import copy
from torch.utils.data import DataLoader
import math

In [24]:
def sentence_generates_achor(sentences, labels, template_dict, LABELS, input_pair=None):
    """Build (anchor, positive, negative) triplets for triplet-loss training.

    For every sentence the anchor is the anchor statement of the sentence's
    own label, the positive is the sentence itself and the negative is one
    sentence drawn with ``np.random.choice`` from the sentences that carry a
    different label.

    Args:
        sentences: Sequence of training texts.
        labels: Sequence of label ids, aligned index-by-index with ``sentences``.
        template_dict: Mapping from label id to the anchor statement of that label.
        LABELS: Unused; kept so existing callers keep working.
        input_pair: List the triplets are appended to (mutated in place).
            A new list is created when omitted.

    Returns:
        ``input_pair`` with one ``InputExample(texts=[anchor, positive, negative])``
        appended per sentence.

    Raises:
        ValueError: If a sentence has no sentence with a different label that
            could serve as its negative (e.g. only one class is present).
    """
    if input_pair is None:
        input_pair = []

    # Convert once up front instead of on every loop iteration.
    sentence_arr = np.array(sentences)
    label_arr = np.array(labels)

    # The candidate negatives only depend on the label, so compute them once per label.
    negatives_by_label = {}

    for idx, current_sentence in enumerate(sentences):
        current_label = labels[idx]
        # Anchor statement that describes the sentence's own class.
        anchor = template_dict[current_label]

        if current_label not in negatives_by_label:
            negatives_by_label[current_label] = sentence_arr[label_arr != current_label]
        negatives = negatives_by_label[current_label]

        if len(negatives) == 0:
            raise ValueError(
                "Cannot build a triplet for label %r: no sentence with a different "
                "label is available as a negative example." % (current_label,)
            )
        second_sentence = np.random.choice(negatives)

        input_pair.append(InputExample(texts=[anchor, current_sentence, second_sentence]))

    return input_pair

def generate_anchors(template: str, labels: list) -> dict:
    """Create the anchor statement of every label.

    The anchor for label id ``i`` is ``template`` directly followed by
    ``labels[i]``; no separator is inserted, so the template should end with
    any whitespace it needs.

    Args:
        template: Text placed in front of every label name.
        labels: Label names, indexed by label id ``0 .. len(labels) - 1``.

    Returns:
        Dict mapping each label id to its anchor statement.
    """
    return {label_id: template + labels[label_id] for label_id in range(len(labels))}

class BaseModel:
    """Sentence-transformer wrapper that remembers its initial weights.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self,  model, max_seq_length: int, add_normalization_layer: bool) -> None:
        """Load the encoder and snapshot its weights.

        Args:
            model: Model name or path handed to ``SentenceTransformer``.
            max_seq_length: Value assigned to the encoder's ``max_seq_length``.
            add_normalization_layer: When true, a ``models.Normalize()`` module
                is stored under key ``"2"`` of the encoder's ``_modules``.
        """
        self.model = SentenceTransformer(model)
        # Deep copy of the freshly loaded weights so the encoder can be reset later.
        self.model_original_state = copy.deepcopy(self.model.state_dict())
        self.model.max_seq_length = max_seq_length

        if add_normalization_layer:
            # `models` is not among the notebook's top-level imports, so this
            # branch used to fail with NameError; import it where it is needed.
            from sentence_transformers import models

            self.model._modules["2"] = models.Normalize()
        

        
class RunFewShot:
    """Few-shot text classifier fine-tuned with anchor-statement triplets.

    The sentence transformer is fine-tuned with ``losses.TripletLoss`` on
    (anchor statement, sentence, sentence of another label) triplets, and a
    ``LogisticRegression`` head is then fitted on top of the tuned encoder.

    Args:
        model_name: Model name or path handed to ``BaseModel``.
        margin: Triplet margin passed to the loss.
        max_seq_length: Maximum sequence length set on the encoder.
        num_iterations: Number of triplet-generation passes over the training
            data; every pass adds one triplet per training sentence.
        batch_size: Batch size of the fine-tuning ``DataLoader``.

    All defaults reproduce the values that used to be hard-coded.
    """

    def __init__(
        self,
        model_name: str = "paraphrase-mpnet-base-v2",
        margin: float = 0.25,
        max_seq_length: int = 128,
        num_iterations: int = 20,
        batch_size: int = 16,
    ) -> None:
        # Configure loss function
        self.loss_class = losses.TripletLoss
        # Hyper-parameters
        self.margin = margin
        self.max_seq_length = max_seq_length
        # Attribute keeps its historical spelling so existing code reading it still works.
        self.num_itaration = num_iterations
        self.batch_size = batch_size

        self.model_wrapper = BaseModel(
            model_name, max_seq_length=self.max_seq_length, add_normalization_layer=False
        )
        self.model = self.model_wrapper.model

    def get_classifier(self, sbert_model: SentenceTransformer) -> SKLearnWrapper:
        """Wrap ``sbert_model`` together with a new ``LogisticRegression`` head."""
        classifier = LogisticRegression()
        return SKLearnWrapper(sbert_model, classifier)

    def train(self, data: Dataset, LABELS: list, template: str) -> SKLearnWrapper:
        """Fine-tune the encoder on anchor triplets and fit the classification head.

        Args:
            data: Training data; ``data["text"]`` and ``data["label"]`` are read.
            LABELS: Label names indexed by label id; each one is appended to
                ``template`` to form the anchor statement of that label.
            template: Text placed in front of every label name.

        Returns:
            A fitted ``SKLearnWrapper`` that owns its own copy of the tuned encoder.
        """
        # Start every training run from the originally loaded weights.
        self.model.load_state_dict(copy.deepcopy(self.model_wrapper.model_original_state))

        x_train = data["text"]
        y_train = data["label"]

        # Triplet loss on cosine distance
        train_loss = self.loss_class(
            model=self.model,
            distance_metric=losses.TripletDistanceMetric.COSINE,
            triplet_margin=self.margin,
        )

        # Both values are identical for every pass, so build them once.
        dict_templates = generate_anchors(template, LABELS)
        sentences = np.array(x_train)

        train_examples = []
        for _ in range(self.num_itaration):
            # Each pass draws new random negatives for every training sentence.
            train_examples = sentence_generates_achor(
                sentences, y_train, dict_templates, LABELS, train_examples
            )

        train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=self.batch_size)
        train_steps = len(train_dataloader)

        # Warm up over the first 10% of the steps.
        warmup_steps = math.ceil(train_steps * 0.1)
        self.model.fit(
            train_objectives=[(train_dataloader, train_loss)],
            epochs=1,
            steps_per_epoch=train_steps,
            warmup_steps=warmup_steps,
            show_progress_bar=False,
        )

        # Train the final classifier on its own copy of the tuned encoder:
        # `self.model` is reset and re-tuned by the next `train()` call, which
        # would otherwise silently change the encoder underneath a classifier
        # that was returned earlier.
        classifier = self.get_classifier(copy.deepcopy(self.model))
        classifier.fit(x_train, y_train)
        return classifier


### Train

In [97]:
# Load your own CSV file (or other source) here; this is just a small
# hand-crafted data example (some samples from SST2).

text = [
    'lend some dignity to a dumb story',
    'The plot is nothing but boilerplate clichés from start to finish',
    'I hate over happy endings',
    'Grant and Bullock are so good together',
    'Emma Watson really fulfilled the role',
    ' The actoring was really good',
]
label = [1, 1, 1, 0, 0, 0]
# Display names of the label ids, used to translate predictions.
# The name for id 0 used to be misspelled as 'positve'.
dict_label = {0: 'positive', 1: 'negative'}
df_train = pd.DataFrame({'text': text, 'label': label})

In [98]:
# Wrap the training frame in a `datasets.Dataset`, the type `RunFewShot.train` is annotated to receive.
data_train = Dataset.from_pandas(df_train, preserve_index=False)

In [99]:
# Anchor template: each label name below is appended to it (no separator is added,
# hence the trailing space) to form the anchor statement of that label.
template = 'The movie review is '
# The position in this list is the label id (index 0 -> label 0, index 1 -> label 1).
label_names = ['positive', 'negative']

In [100]:
# Build the few-shot trainer; see `RunFewShot.__init__` for the encoder and hyper-parameters it sets up.
ancSetfit = RunFewShot()

In [101]:
# Fine-tune on the sentiment anchors and get back the fitted classifier.
classifier = ancSetfit.train(data=data_train, LABELS=label_names, template=template)


### Predict

In [109]:
# Load your own CSV file here; this is just a small data example.
text_pred = [
    'This wat not the best actoring from Mads Mikkelsen',
    'the story line is very weak',
    'Mads Mikkelsen was really a good choice',
    'the story line is really cool',
]
# One-column frame with the texts to classify, plus its `Dataset` counterpart.
df_pred = pd.DataFrame(dict(text=text_pred))
data_pred = Dataset.from_pandas(df_pred, preserve_index=False)

In [110]:


# Run the sentiment classifier on the prediction texts; the result is mapped to label names in the next cell.
y_pred = classifier.predict(data_pred['text'])

In [111]:
# Translate the predictions into label names via `dict_label` and store them as a new column of `df_pred`.
df_pred['pred_sentiment'] = pd.Series(y_pred).map(dict_label)

In [112]:
# Display the texts next to their predicted sentiment.
df_pred

Unnamed: 0,text,pred_sentiment
0,This wat not the best actoring from Mads Mikke...,negative
1,the story line is very weak,negative
2,Mads Mikkelsen was really a good choice,positve
3,the story line is really cool,positve


### What if we have the same training data, but the 1/0 labels are not about sentiment but about topic?
Now imagine that, with the same small training data, what we wanted to classify with the 0 and 1 labels was not sentiment, but whether the movie review is concerned with the actors or the plot.
Try to classify this by only changing the anchor statement.

In [113]:
# New anchor statement: same label ids, but now read as topic (0 = actors, 1 = plot).
dict_label_topic =  {0:'actors', 1: 'plot'}
# NOTE(review): `template` and `label_names` are rebound here, replacing the sentiment
# values defined earlier; re-running the sentiment training cell after this one would
# train on the topic anchors.
template = 'The movie review is concerning '
label_names = ['the actors', 'the plot']
# Train on the same data; only the anchor statements differ.
classifier_topic = ancSetfit .train(data_train, label_names, template)

# Predict the topic and store its name as a new column of `df_pred`.
y_pred_topic = classifier_topic.predict(data_pred['text'])
df_pred['pred_topic'] = pd.Series(y_pred_topic).map(dict_label_topic)

# The resulting frame is displayed in a later cell.


In [114]:
# NOTE(review): repeats the assignment made at the end of the previous cell; it
# yields the same column again, so one of the two could be dropped.
df_pred['pred_topic'] = pd.Series(y_pred_topic).map(dict_label_topic)


In [115]:
# Display the texts with both the sentiment and the topic predictions.
df_pred

Unnamed: 0,text,pred_sentiment,pred_topic
0,This wat not the best actoring from Mads Mikke...,negative,actors
1,the story line is very weak,negative,plot
2,Mads Mikkelsen was really a good choice,positve,actors
3,the story line is really cool,positve,plot
