<a href="https://colab.research.google.com/github/MorenoSara/Few-Shot_Text_Classification/blob/main/Supervised_Text_Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install (or upgrade to the latest release of) sentence-transformers in the notebook runtime.
!pip install -U sentence-transformers

In [None]:
import pandas as pd
from sentence_transformers import SentenceTransformer
import torch
from sentence_transformers.util import cos_sim
import numpy as np
from scipy import sparse
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from torch.utils.data.dataset import Dataset
from torch import nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel

# Run on a CUDA GPU when PyTorch reports one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Read the train / validation / test spreadsheets; the first column of each
# file becomes the DataFrame index.
train_dataset, eval_dataset, test_dataset = (
  pd.read_excel(split_file, index_col=0)
  for split_file in ('train.xlsx', 'valid.xlsx', 'test.xlsx')
)  # training split: 32889 samples

In [None]:
# Raw domain code -> full domain name (applied to the 'Domain' column of each split).
REMAP_LEV1 = dict(
  CS='Computer Science',
  Civil='Civil Engineering',
  ECE='Electrical Engineering',
  Psychology='Psychology',
  MAE='Mechanical Engineering',
  Medical='Medical Science',
  biochemistry='Biochemistry',
)

In [None]:
def get_mapped_labels(data, mapping_dict):
  """Translate raw labels into their mapped values.

  Args:
    data: iterable of label strings; whitespace around each label is
      stripped before the lookup.
    mapping_dict: dictionary from stripped label to its replacement.

  Returns:
    A list with one mapped value per element of `data`, in iteration order.

  Raises:
    KeyError: if a stripped label is not a key of `mapping_dict`.
  """
  return [mapping_dict[label.strip()] for label in data]

In [None]:
# Documents (abstracts) and mapped domain labels of each split.
training_docs = train_dataset['Abstract']
training_labels = get_mapped_labels(train_dataset['Domain'], REMAP_LEV1)

eval_docs = eval_dataset['Abstract']
eval_labels = get_mapped_labels(eval_dataset['Domain'], REMAP_LEV1)

test_docs = test_dataset['Abstract']
test_labels = get_mapped_labels(test_dataset['Domain'], REMAP_LEV1)

# Distinct class names present in the training split. They are de-duplicated
# AFTER stripping/mapping: raw codes that differ only by surrounding whitespace
# map to the same name, and duplicates here would inflate len(labels_set),
# which is later used as the number of classes. Sorted for a reproducible order.
labels_set = sorted(set(training_labels))

print(f"Training set: {len(training_docs)}, {len(training_labels)}") # 32889 samples
print(f"Evaluation set: {len(eval_docs)}, {len(eval_labels)}") # 4698 samples
print(f"Test set: {len(test_docs)}, {len(test_labels)}") # 9398 samples

In [None]:
# Class name -> integer id.
le = LabelEncoder()
integer_labels = le.fit_transform(labels_set)

int_training_labels = le.transform(training_labels)
int_eval_labels = le.transform(eval_labels)
int_test_labels = le.transform(test_labels)

# Integer id -> dense one-hot row.
# The `sparse` keyword was renamed to `sparse_output` in scikit-learn 1.2 and
# removed in 1.4, so try the current name first and fall back for old installs.
try:
  ohe = OneHotEncoder(sparse_output=False)
except TypeError:
  ohe = OneHotEncoder(sparse=False)
ohe.fit(integer_labels.reshape(-1,1))

ohe_training_labels = ohe.transform(int_training_labels.reshape(-1,1)) # (32889, 7)
ohe_eval_labels = ohe.transform(int_eval_labels.reshape(-1,1)) # (4698, 7)
ohe_test_labels = ohe.transform(int_test_labels.reshape(-1,1)) # (9398, 7)

In [None]:
class document_class(Dataset):
  """Map-style dataset that pairs every document with its label.

  Each item is a two-element list ``[label, document]``.
  """

  def __init__(self, documents, labels):
    """Store one ``[label, document]`` pair per document.

    Args:
      documents: iterable of documents.
      labels: sequence indexed by position (0, 1, ...); ``labels[i]`` is the
        label of the i-th document yielded by `documents`.
    """
    # Labels are looked up by running position, not by any index the
    # `documents` container may carry.
    self.train_df = [[labels[position], doc] for position, doc in enumerate(documents)]

  def __getitem__(self, index):
    """Return the ``[label, document]`` pair stored at `index`."""
    return self.train_df[index]

  def __len__(self):
    """Return the number of stored pairs."""
    return len(self.train_df)

In [None]:
def my_collate_fn(batch):
  """Collate ``[label, document]`` samples into one batch.

  Args:
    batch: list of samples, where ``sample[0]`` is the label (an iterable of
      numbers) and ``sample[1]`` is the document.

  Returns:
    A tuple ``(documents, labels)``: the documents as a list in batch order
    and the labels gathered into a single ``torch.Tensor``.
  """
  texts = [sample[1] for sample in batch]
  targets = [list(sample[0]) for sample in batch]
  return texts, torch.Tensor(targets)

In [None]:
def save_model(model, model_path):
    """Write the model's state dict to `model_path` using torch.save."""
    weights = model.state_dict()
    torch.save(weights, model_path)

def load_model(model, model_path, use_cuda=True):
    """Load a saved state dict from `model_path` into `model`.

    The stored tensors are mapped to 'cuda:0' when `use_cuda` is true and CUDA
    is available, otherwise to the CPU. The same `model` object is returned.
    """
    target = 'cuda:0' if (use_cuda and torch.cuda.is_available()) else 'cpu'
    state = torch.load(model_path, target)
    model.load_state_dict(state)
    return model

## Train the classifier while keeping the pre-trained sentence transformer frozen

In [None]:
class complete_model_no_st_finetuning(nn.Module):
  """Linear classifier on top of a pretrained sentence transformer that is kept fixed.

  Only the linear classification layer is meant to be trained.
  """

  def __init__(self, sentence_transformer_model, st_embedding_dimension, num_classes, device):
    """Build the encoder and the classification layer and move them to `device`.

    Args:
      sentence_transformer_model: model name or path given to SentenceTransformer.
      st_embedding_dimension: input size of the linear layer; it has to match
        the size of the embeddings produced by the encoder.
      num_classes: number of scores produced for every document.
      device: device the module is moved to and on which forward() places
        the embeddings.
    """
    super().__init__()
    self.st = SentenceTransformer(sentence_transformer_model)
    # SentenceTransformer is itself an nn.Module, so its weights are part of
    # self.parameters(). Mark them as non-trainable so that a `requires_grad`
    # filter (as used when building the optimizer) selects only the linear head.
    self.st.requires_grad_(False)
    self.classification = nn.Linear(in_features=st_embedding_dimension, out_features=num_classes)
    self.device = device
    self.to(device)

  def forward(self, documents):
    """Return one row of unnormalised class scores (logits) per document."""
    # encode() returns a NumPy array by default; converting it to a new tensor
    # means no gradient flows back into the encoder from here.
    embeddings = self.st.encode(documents)
    embeddings = torch.as_tensor(embeddings, dtype=torch.float32, device=self.device)
    return self.classification(embeddings)

In [None]:
# Pretrained all-mpnet-base-v2 encoder with a linear classification layer on
# top; 768 is passed as the embedding dimension, one output per class name.
model = complete_model_no_st_finetuning(
  sentence_transformer_model='sentence-transformers/all-mpnet-base-v2',
  st_embedding_dimension=768,
  num_classes=len(labels_set),
  device=device,
)

In [None]:
# Hyper-parameters for training the classification layer.
lr = 1e-2
epochs = 5
batch_size = 256

# CrossEntropyLoss is applied to the raw model outputs; the one-hot rows act as
# class-probability targets.
criterion = torch.nn.CrossEntropyLoss()
# Only parameters that still require gradients are handed to the optimizer.
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, betas=(0.9, 0.999), eps=1e-8)

training_documents = document_class(training_docs, ohe_training_labels)
training_dataloader = DataLoader(training_documents, batch_size=batch_size, shuffle=True, collate_fn=my_collate_fn)

# No shuffling for evaluation: the summed per-batch loss drives checkpoint
# selection, and a fixed batch composition keeps it comparable across epochs.
eval_documents = document_class(eval_docs, ohe_eval_labels)
eval_dataloader = DataLoader(eval_documents, batch_size=batch_size, shuffle=False, collate_fn=my_collate_fn)

In [None]:
# Train for `epochs` epochs; after each epoch evaluate on the validation split
# and keep the checkpoint with the lowest summed validation loss.
best_eval_loss = np.inf

for epoch in range(epochs):

  # ---- training pass over the whole training set ----
  model.train()
  training_loss = 0.0

  for batch, (docs, labels) in enumerate(training_dataloader):

    labels = labels.to(device)

    optimizer.zero_grad()
    probabilities = model(docs)  # raw class scores (logits), despite the name
    loss = criterion(probabilities, labels)
    loss.backward()
    optimizer.step()

    # Accumulate a plain float: summing the loss tensors themselves would keep
    # autograd history alive from one batch to the next.
    training_loss += loss.item()

    print(f'Batch: {batch}/{len(training_dataloader)}, epoch: {epoch}/{epochs}. Training loss: {training_loss:.3f}.')
    # NOTE: the former `break` here stopped every epoch after a single batch.

  # ---- evaluation pass ----
  model.eval()
  eval_loss = 0.0

  with torch.no_grad():
    # `eval_targets` (not `eval_labels`) so the notebook-level list of
    # evaluation labels is not overwritten by a tensor.
    for eval_batch, (eval_docums, eval_targets) in enumerate(eval_dataloader):
      eval_targets = eval_targets.to(device)

      eval_probs = model(eval_docums)

      batch_eval_loss = criterion(eval_probs, eval_targets)
      eval_loss += batch_eval_loss.item()
      print(f'Evaluation: Batch: {eval_batch}/{len(eval_dataloader)}, epoch: {epoch}/{epochs}. Evaluation loss: {eval_loss:.3f}.')

  print("\nEvaluation loss: ", eval_loss)
  print('\n')

  if eval_loss < best_eval_loss:
    print("Saving best model")
    best_eval_loss = eval_loss
    save_model(model, './best_model.pkl')


## Train the classifier and fine-tune the sentence transformer

In [None]:
class TextEncoder(nn.Module):
    """Transformer text encoder: tokenise, run the model, mean-pool, L2-normalise.

    Unlike SentenceTransformer.encode, forward() stays inside the autograd
    graph, so the transformer weights can be fine-tuned.
    """

    def __init__(self, device, model_name: str = 'sentence-transformers/all-mpnet-base-v2') -> None:
        """Load the tokenizer and model called `model_name`; put the model on `device`."""
        super(TextEncoder, self).__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(device)

    def forward(self, text) -> torch.Tensor:
        """Embed `text`.

        Args:
            text: a single string or a list of strings.

        Returns:
            L2-normalised embeddings: a 1-D tensor for a single string,
            otherwise one row per input string.
        """
        inp = self.tokenizer(text, padding=True, truncation=True, return_tensors='pt')
        # Send the inputs to wherever the model currently lives instead of
        # relying on a module-level `device` global.
        inp = inp.to(self.model.device)
        out = self.model(**inp)[0]  # First element of model_output contains all token embeddings.
        out = self.mean_pooling(out, inp['attention_mask'])
        if isinstance(text, str):  # If input is just 1 string -> return 1D embeddings.
            out = out.squeeze(0)
        return nn.functional.normalize(out, p=2, dim=-1)

    def mean_pooling(self, token_embeddings, attention_mask):
        """Average the token embeddings over dim 1, counting only unmasked tokens.

        `attention_mask` is broadcast over the embedding dimension; the
        denominator is clamped to 1e-9 so a fully masked row yields zeros
        instead of dividing by zero.
        """
        input_mask_expanded = \
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

In [None]:
class complete_model_finetuning(nn.Module):
  """TextEncoder followed by a linear layer that scores every class."""

  def __init__(self, sentence_transformer_model, st_embedding_dimension, num_classes, device):
    """Create the encoder and the linear layer, then move the module to `device`.

    Args:
      sentence_transformer_model: model name handed to TextEncoder.
      st_embedding_dimension: input size of the linear layer.
      num_classes: output size of the linear layer.
      device: device given to TextEncoder and used for the whole module.
    """
    super().__init__()
    self.st = TextEncoder(device, sentence_transformer_model)
    self.classification = nn.Linear(st_embedding_dimension, num_classes)
    self.device = device
    self.to(device)

  def forward(self, documents):
    """Encode `documents` and return the linear layer's score for each class."""
    return self.classification(self.st(documents))

In [None]:
# all-mpnet-base-v2 encoder with a linear classification layer on top;
# 768 is passed as the embedding dimension, one output per class name.
model = complete_model_finetuning(
  sentence_transformer_model='sentence-transformers/all-mpnet-base-v2',
  st_embedding_dimension=768,
  num_classes=len(labels_set),
  device=device,
)

In [None]:
# Hyper-parameters for fine-tuning.
lr = 1e-2          # learning rate of the newly initialised classification layer
encoder_lr = 2e-5  # far smaller rate for the pretrained transformer: updating it
                   # with 1e-2 would wipe out the pretrained weights
epochs = 5
batch_size = 16

# CrossEntropyLoss is applied to the raw model outputs; the one-hot rows act as
# class-probability targets.
criterion = torch.nn.CrossEntropyLoss()
# Two parameter groups so encoder and classifier can use different learning rates.
optimizer = optim.Adam(
  [
    {'params': [p for p in model.st.parameters() if p.requires_grad], 'lr': encoder_lr},
    {'params': [p for p in model.classification.parameters() if p.requires_grad], 'lr': lr},
  ],
  lr=lr,
  betas=(0.9, 0.999),
  eps=1e-8,
)

training_documents = document_class(training_docs, ohe_training_labels)
training_dataloader = DataLoader(training_documents, batch_size=batch_size, shuffle=True, collate_fn=my_collate_fn)

# No shuffling for evaluation: the summed per-batch loss drives checkpoint
# selection, and a fixed batch composition keeps it comparable across epochs.
eval_documents = document_class(eval_docs, ohe_eval_labels)
eval_dataloader = DataLoader(eval_documents, batch_size=batch_size, shuffle=False, collate_fn=my_collate_fn)

In [None]:
# Fine-tune for `epochs` epochs; after each epoch evaluate on the validation
# split and keep the checkpoint with the lowest summed validation loss.
best_eval_loss = np.inf

for epoch in range(epochs):

  # ---- training pass over the whole training set ----
  model.train()
  training_loss = 0.0

  for batch, (docs, labels) in enumerate(training_dataloader):

    labels = labels.to(device)

    optimizer.zero_grad()
    probabilities = model(docs)  # raw class scores (logits), despite the name
    loss = criterion(probabilities, labels)
    loss.backward()
    optimizer.step()

    # Accumulate a plain float: summing the loss tensors themselves would keep
    # autograd history alive from one batch to the next.
    training_loss += loss.item()

    print(f'Batch: {batch}/{len(training_dataloader)}, epoch: {epoch}/{epochs}. Training loss: {training_loss:.3f}.')
    # NOTE: the former `break` here stopped every epoch after a single batch.

  # ---- evaluation pass ----
  model.eval()
  eval_loss = 0.0

  with torch.no_grad():
    # `eval_targets` (not `eval_labels`) so the notebook-level list of
    # evaluation labels is not overwritten by a tensor.
    for eval_batch, (eval_docums, eval_targets) in enumerate(eval_dataloader):
      eval_targets = eval_targets.to(device)

      eval_probs = model(eval_docums)

      batch_eval_loss = criterion(eval_probs, eval_targets)
      eval_loss += batch_eval_loss.item()
      print(f'Evaluation: Batch: {eval_batch}/{len(eval_dataloader)}, epoch: {epoch}/{epochs}. Evaluation loss: {eval_loss:.3f}.')

  print("\nEvaluation loss: ", eval_loss)
  print('\n')

  if eval_loss < best_eval_loss:
    print("Saving best model")
    best_eval_loss = eval_loss
    save_model(model, './finetuned_model.pkl')