In [None]:
# !pip install lightning
# !pip install sentencepiece

In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl
from transformers import AutoModelForSequenceClassification, AutoTokenizer

import pandas as pd
from scipy.stats import spearmanr, pearsonr
from sklearn.model_selection import train_test_split
import os
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import nltk
from nltk.corpus import wordnet
import random
from sklearn.model_selection import KFold

import re

In [None]:
# Fetch the WordNet corpus; str_dataset.get_synonyms queries it through nltk.corpus.wordnet.
nltk.download('wordnet')

class str_dataset(torch.utils.data.Dataset):
    """Torch dataset that tokenizes the ``Text`` column and returns it with ``Score``.

    Each ``Text`` entry has its newline characters replaced by ``sep`` and is
    stored in an ``input`` column. Items are tokenized on access with the
    ``distilbert-base-uncased`` tokenizer, padded/truncated to ``max_length``.

    Args:
        dataframe: frame with a ``Text`` and a ``Score`` column. It is copied
            and re-indexed to 0..n-1, so the caller's frame is never modified
            and any incoming index (e.g. an ``iloc`` slice) is accepted.
        syn_replace: when True, an item is augmented with probability
            ``aug_prob`` by swapping one word for a WordNet synonym.
        sep: string that replaces newlines in ``Text``.
        max_length: tokenizer padding/truncation length.
        aug_prob: per-item probability of applying synonym replacement.
    """

    def __init__(self, dataframe, syn_replace = False, sep = '[SEP]', max_length = 265, aug_prob = 0.3):
        # reset_index returns a new frame: positional lookups in __getitem__ become
        # valid and the 'input' column is not written into the caller's object.
        self.dataframe = dataframe.reset_index(drop=True)
        self.tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
        self.syn_replace = syn_replace
        self.sep = sep
        self.max_length = max_length
        self.aug_prob = aug_prob

        self.dataframe['input'] = self.dataframe['Text'].str.replace('\n', sep, regex=False)

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask, score)`` for the row at position ``idx``."""
        # Positional access: DataLoader samplers yield 0..len-1, not index labels.
        features = self.dataframe['input'].iloc[idx]

        if self.syn_replace and random.random() < self.aug_prob:
            features = self.apply_augmentation(features)

        token = self.tokenizer(
            features,
            return_tensors='pt',
            truncation='longest_first',
            padding='max_length',
            max_length=self.max_length,
        )

        # Drop only the leading batch dimension added by return_tensors='pt'.
        input_ids = token['input_ids'].squeeze(0)
        attention_mask = token['attention_mask'].squeeze(0)

        score = float(self.dataframe['Score'].iloc[idx])
        score = torch.tensor(score, dtype=torch.float32).unsqueeze(dim=0)

        return input_ids, attention_mask, score

    def get_synonyms(self,word):
        """Return the distinct WordNet lemma names found across all synsets of ``word``."""
        synonyms = {lemma.name() for syn in wordnet.synsets(word) for lemma in syn.lemmas()}
        return list(synonyms)

    def get_replacable_word(self,seq,sep = '[SEP]'):
        """Pick a random word of ``seq`` that has at least one synonym.

        Returns None when no word qualifies (previously this raised IndexError).
        """
        # Replace the separator with a space so the words on either side stay distinct.
        if sep:
            seq = seq.replace(sep, ' ')
        seq = re.sub(r"[^a-zA-Z0-9]+", ' ', seq)
        # Sorted so that a seeded `random` gives reproducible picks (set order is not stable).
        words = sorted(set(seq.split()))
        candidates = [cand for cand in words if self.get_synonyms(cand) != []]

        if not candidates:
            return None
        return random.choice(candidates)

    def apply_augmentation(self,seq):
        """Replace every whole-word occurrence of one random word with a synonym.

        Returns ``seq`` unchanged when there is nothing to replace.
        """
        word = self.get_replacable_word(seq, self.sep)
        if word is None:
            return seq

        # WordNet joins multi-word lemmas with '_'; also drop the word itself,
        # which would make the "augmentation" a no-op.
        options = sorted({
            name.replace('_', ' ')
            for name in self.get_synonyms(word)
            if name.lower() != word.lower()
        })
        if not options:
            return seq
        replacement = random.choice(options)

        # \b anchors restrict the swap to whole words ('cat' must not touch 'concatenate').
        pattern = re.compile(r'\b' + re.escape(word) + r'\b')
        # Substitute per segment so the separator token itself is never rewritten.
        parts = seq.split(self.sep) if self.sep else [seq]
        # A callable replacement avoids backslash/group interpretation of the synonym.
        return self.sep.join(pattern.sub(lambda _match: replacement, part) for part in parts)



class STR_DataModule(pl.LightningDataModule):
    """Lightning data module that splits ``train_data`` 80/20 into train and validation sets.

    Args:
        train_data: frame passed on to ``str_dataset`` after the split.
        batch_size: batch size used by both dataloaders.
        syn_replace: enables synonym augmentation on the training split only.
    """

    def __init__(self, train_data, batch_size=32, syn_replace = False):
        super().__init__()
        self.batch_size = batch_size
        self.train_data = train_data
        self.syn_replace = syn_replace

        # Lightning reads `prepare_data_per_node` as an attribute that the base
        # __init__ assigns, so a method with that name is shadowed and never used.
        # Set the intended value on the instance instead.
        self.prepare_data_per_node = False

        self.train_dataset = None
        self.val_dataset = None

    def setup(self, stage=None):
        """Create the train/validation datasets (idempotent).

        State is assigned here rather than in ``prepare_data``: Lightning runs
        ``prepare_data`` on a single process only, so attributes set there would
        be missing on the other processes of a distributed run.
        """
        if self.train_dataset is not None and self.val_dataset is not None:
            return

        train_df, val_df = train_test_split(self.train_data, test_size=0.2, random_state=42)
        self.train_dataset = str_dataset(train_df.reset_index(drop=True),syn_replace = self.syn_replace)
        # Validation data is never augmented.
        self.val_dataset = str_dataset(val_df.reset_index(drop=True),syn_replace = False)

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        if self.train_dataset is None:
            # Allows use outside a Trainer, where setup() is not called automatically.
            self.setup()
        return DataLoader(self.train_dataset, batch_size=self.batch_size,num_workers = 2, shuffle=True)

    def val_dataloader(self):
        """Unshuffled loader over the validation split."""
        if self.val_dataset is None:
            self.setup()
        return DataLoader(self.val_dataset, batch_size=self.batch_size,num_workers = 2)

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [None]:
class SentenceSimilarityModel(pl.LightningModule):
    """DistilBERT sequence-classification head (one output) trained as a regressor.

    The single logit is passed through a sigmoid and fitted to the target score
    with MSE. Spearman correlation is computed once per epoch over all
    predictions of that epoch; a per-batch value averaged over batches is not
    the dataset-level correlation and is NaN for size-1 or constant batches.

    Args:
        learning_rate: AdamW learning rate.
    """

    def __init__(self, learning_rate=1.737e-6):
        super().__init__()
        self.berta = AutoModelForSequenceClassification.from_pretrained('distilbert-base-uncased',ignore_mismatched_sizes=True,num_labels=1)

        self.learning_rate = learning_rate

        # Per-stage buffers of CPU tensors, consumed and cleared at each epoch end.
        # NOTE: buffers are per process; in a multi-process run each rank logs its own value.
        self._epoch_buffers = {
            stage: {"preds": [], "targets": []} for stage in ("train", "val", "test")
        }

    def forward(self, input_ids,attention_mask):
        """Return sigmoid-squashed predictions for the batch."""
        outputs = self.berta(input_ids, attention_mask=attention_mask)

        logits = outputs.logits
        return torch.sigmoid(logits)

    def _shared_step(self, batch, stage):
        """Run the model on ``batch``, buffer predictions/targets, return the MSE loss."""
        input_ids, attention_mask, score = batch
        preds = self(input_ids, attention_mask)

        loss = F.mse_loss(preds, score)

        buffers = self._epoch_buffers[stage]
        # Cast to float32 on CPU so mixed-precision outputs rank consistently.
        buffers["preds"].append(preds.detach().float().flatten().cpu())
        buffers["targets"].append(score.detach().float().flatten().cpu())
        return loss

    def _log_epoch_spearman(self, stage):
        """Log ``{stage}_spearman`` over everything buffered this epoch, then clear."""
        buffers = self._epoch_buffers[stage]
        preds, targets = buffers["preds"], buffers["targets"]

        rho = float("nan")
        if preds:
            all_preds = torch.cat(preds).numpy()
            all_targets = torch.cat(targets).numpy()
            if all_preds.size >= 2:
                # Index 0 is the correlation on every SciPy version
                # (the `.statistic` attribute only exists on recent ones).
                rho = float(spearmanr(all_preds, all_targets)[0])

        self.log(f"{stage}_spearman", rho, on_step=False, on_epoch=True, prog_bar=True)

        preds.clear()
        targets.clear()

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch, "train")
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def on_train_epoch_end(self):
        self._log_epoch_spearman("train")

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch, "val")
        self.log("val_loss", loss,on_epoch = True, prog_bar = True)

    def on_validation_epoch_end(self):
        self._log_epoch_spearman("val")

    def test_step(self, batch, batch_idx):
        """Needed for ``trainer.test``; logs ``test_loss`` (and ``test_spearman`` at epoch end)."""
        loss = self._shared_step(batch, "test")
        self.log("test_loss", loss, on_epoch = True, prog_bar = True)

    def on_test_epoch_end(self):
        self._log_epoch_spearman("test")

    def configure_optimizers(self):
        """AdamW with a plateau scheduler that halves the LR when ``val_loss`` stalls."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)

        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, min_lr=1e-8)

        return {'optimizer': optimizer, 'lr_scheduler': {'scheduler': scheduler, 'monitor': 'val_loss'}}



In [None]:
if __name__ == '__main__':
    # 5-fold cross-validation: train on four folds (with an internal 80/20
    # train/validation split done by STR_DataModule), then score the held-out fold.

    # Synonym augmentation and shuffling are random; seed for reproducible runs.
    pl.seed_everything(42, workers=True)

    kf = KFold(n_splits=5, shuffle=True, random_state=42)

    train_data = pd.read_csv('/content/eng_train.csv')

    all_spearman_corrs = []

    for fold, (train_idx, val_idx) in enumerate(kf.split(train_data)):

        # Re-index both partitions so dataset lookups by position 0..n-1 are valid.
        fold_train = train_data.iloc[train_idx].reset_index(drop=True)
        fold_holdout = train_data.iloc[val_idx].reset_index(drop=True)

        str_datamodule = STR_DataModule(train_data = fold_train, batch_size = 32, syn_replace = True)

        model = SentenceSimilarityModel()

        trainer = pl.Trainer(max_epochs=30,precision="16-mixed",callbacks=[EarlyStopping(monitor="val_loss", patience=3, mode="min")],accelerator='gpu')

        # tuner = pl.tuner.Tuner(trainer)
        # lr_finder = tuner.lr_find(model, datamodule = str_datamodule, min_lr = 1e-7,max_lr= 1e-2)

        # # Pick point based on plot, or get suggestion
        # new_lr = lr_finder.suggestion()

        # #  update hparams of the model
        # model.learning_rate = new_lr

        trainer.fit(model, datamodule=str_datamodule)

        val_set = str_dataset(fold_holdout)

        val_dataloader = DataLoader(val_set, batch_size = 16, num_workers = 2)

        # Run the validation loop on the held-out fold: the model defines
        # validation_step, and validate() returns the logged metrics directly
        # instead of relying on whatever is left in trainer.callback_metrics.
        fold_metrics = trainer.validate(model, dataloaders=val_dataloader)

        # Report the Spearman correlation logged for this fold
        average_spearman_corr = float(fold_metrics[0]["val_spearman"])
        print(
            f"Average Spearman Correlation for Fold {fold + 1}: {average_spearman_corr}"
        )

        all_spearman_corrs.append(average_spearman_corr)

    # Calculate and print the overall average Spearman correlation
    overall_average_spearman_corr = sum(all_spearman_corrs) / len(all_spearman_corrs)
    print(
        f"Overall Average Spearman Correlation across all folds: {overall_average_spearman_corr}"
    )


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'pre_classifier.bias', 'classifier.weight', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
INFO:pytorch_lightning.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type                                | Params
------------------------------------------------

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]