In [2]:

import pandas as pd
from tqdm import tqdm
from sentence_transformers import InputExample, models
from torch import nn
import torch.optim as optim
import torch


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Using an already pretrained SBERT to compute cosine similarities which would then be used 

df=pd.read_csv("stsb_train.csv")
df=df.drop("idx",axis=1)
df_cosine_labels = df.copy() 
df_cosine_labels["label"] = df_cosine_labels["label"] /df_cosine_labels["label"].abs().max() 


df=pd.read_csv("stsb_validation.csv")
df=df.drop("idx",axis=1)
validation_cosine_labels = df.copy() 
validation_cosine_labels["label"] = validation_cosine_labels["label"] /validation_cosine_labels["label"].abs().max() 



In [4]:

from torch.utils.data import DataLoader
train_examples = []
for i,row in df_cosine_labels.iterrows():
    train_examples.append([[row["sentence1"], row["sentence2"]],row["label"]])

train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)

val_examples = []
for i,row in validation_cosine_labels.iterrows():
    val_examples.append([[row["sentence1"], row["sentence2"]],row["label"]])

val_dataloader = DataLoader(val_examples, shuffle=True, batch_size=16)

In [17]:
from transformers import DistilBertModel,DistilBertTokenizer
class ShroomBasHoGayaBhai(nn.Module):
    def __init__(self,max_seq_length):
        super(ShroomBasHoGayaBhai, self).__init__()
        self.word_embedding_model = models.Transformer("distilbert-base-uncased", max_seq_length=max_seq_length)
        self.pooling_model = models.Pooling(self.word_embedding_model.get_word_embedding_dimension())
        self.dense_model = models.Dense(
            in_features=self.pooling_model.get_sentence_embedding_dimension(),
            out_features=256,
            activation_function=nn.Tanh(),
        )
        self.cos = nn.CosineSimilarity(dim=1, eps=1e-6)

    def forward_once(self, x):
        embedding = self.word_embedding_model(x)
        embedding = self.pooling_model(embedding)
        embedding = self.dense_model(embedding)
 

        return embedding

    def forward(self, sentence1, sentence2):
        output1 = self.forward_once(sentence1)["sentence_embedding"]
        output2 = self.forward_once(sentence2)["sentence_embedding"]
        similarity_score = self.cos(output1, output2)
  
        return similarity_score

def encode_pair(pair,padding=True,truncation=True,max_length=256):
    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
    return tokenizer(
    pair[0], padding=padding,truncation=truncation,max_length=max_length, return_tensors="pt"
    ), tokenizer(
        pair[1],padding=padding,truncation=truncation,max_length=max_length, return_tensors="pt"
    )

model = ShroomBasHoGayaBhai(256)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.0000001)
num_epochs=50
best_val_loss = float('inf')
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Total parameters in the model:", total_params)
for epoch in tqdm(range(num_epochs)):
    running_loss = 0.0
    correct = 0
    total = 0
    model.train()
    for i, data in enumerate(train_dataloader, 0):
        if i%100==0:
            print(f"Batch: {i+1}")
        sentences, labels = data
        optimizer.zero_grad()
        sentence1,sentence2=encode_pair(sentences)
        labels = labels.float()
        outputs=model(sentence1,sentence2)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for val_data in val_dataloader:
            val_sentences, val_labels = val_data
            val_sentence1, val_sentence2 = encode_pair(val_sentences)
            val_labels = val_labels.float()
            val_outputs = model(val_sentence1, val_sentence2)
            val_loss += criterion(val_outputs, val_labels).item()
    val_loss /= len(val_dataloader)

    print(f'Epoch {epoch+1} | Train Loss: {running_loss / len(train_dataloader)} | Validation Loss: {val_loss}')

    # if val_loss < best_val_loss:
    #     torch.save(model.state_dict(), "best_model_v2_lr1.pth")
    #     best_val_loss = val_loss


Total parameters in the model: 66559744


  0%|          | 0/50 [00:00<?, ?it/s]

Batch: 1
Batch: 101
Batch: 201
Batch: 301


  2%|▏         | 1/50 [03:59<3:15:38, 239.56s/it]

Epoch 1 | Train Loss: 0.17361739428920878 | Validation Loss: 0.18584504137013821
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  4%|▍         | 2/50 [07:53<3:09:06, 236.39s/it]

Epoch 2 | Train Loss: 0.14530529578526816 | Validation Loss: 0.15166122204762825
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  6%|▌         | 3/50 [11:45<3:03:37, 234.42s/it]

Epoch 3 | Train Loss: 0.11903340433620745 | Validation Loss: 0.1229275793946804
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  8%|▊         | 4/50 [16:09<3:08:33, 245.94s/it]

Epoch 4 | Train Loss: 0.09811013885256317 | Validation Loss: 0.10136952005485271
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 10%|█         | 5/50 [20:21<3:06:13, 248.30s/it]

Epoch 5 | Train Loss: 0.08306165556423366 | Validation Loss: 0.08623752489368966
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 12%|█▏        | 6/50 [24:35<3:03:26, 250.16s/it]

Epoch 6 | Train Loss: 0.07224878998887208 | Validation Loss: 0.07522154813434215
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 14%|█▍        | 7/50 [28:49<3:00:14, 251.50s/it]

Epoch 7 | Train Loss: 0.06426240238361061 | Validation Loss: 0.06751771090908888
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 16%|█▌        | 8/50 [33:10<2:57:57, 254.24s/it]

Epoch 8 | Train Loss: 0.05925389516891705 | Validation Loss: 0.06198900301960555
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 18%|█▊        | 9/50 [37:18<2:52:31, 252.48s/it]

Epoch 9 | Train Loss: 0.054803589296837645 | Validation Loss: 0.05756623395025096
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 20%|██        | 10/50 [41:24<2:47:00, 250.52s/it]

Epoch 10 | Train Loss: 0.05140574430115521 | Validation Loss: 0.05427072722306277
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 22%|██▏       | 11/50 [45:30<2:41:51, 249.02s/it]

Epoch 11 | Train Loss: 0.04849770855055087 | Validation Loss: 0.05167022688274688
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 24%|██▍       | 12/50 [49:30<2:35:58, 246.27s/it]

Epoch 12 | Train Loss: 0.04664365247849168 | Validation Loss: 0.049391815627112666
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 26%|██▌       | 13/50 [53:22<2:29:18, 242.11s/it]

Epoch 13 | Train Loss: 0.04495269546750933 | Validation Loss: 0.04762057642987434
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 28%|██▊       | 14/50 [57:14<2:23:18, 238.84s/it]

Epoch 14 | Train Loss: 0.04348622896812028 | Validation Loss: 0.04607306335596962
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 30%|███       | 15/50 [1:01:16<2:19:57, 239.92s/it]

Epoch 15 | Train Loss: 0.04220656094969147 | Validation Loss: 0.044698333352843816
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 32%|███▏      | 16/50 [1:05:24<2:17:16, 242.25s/it]

Epoch 16 | Train Loss: 0.04081043775141653 | Validation Loss: 0.04360899057714863
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 34%|███▍      | 17/50 [1:09:38<2:15:11, 245.80s/it]

Epoch 17 | Train Loss: 0.04011172445801397 | Validation Loss: 0.042633578220897535
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 36%|███▌      | 18/50 [1:13:49<2:11:57, 247.43s/it]

Epoch 18 | Train Loss: 0.039270955875205495 | Validation Loss: 0.04180851152681924
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 38%|███▊      | 19/50 [1:17:53<2:07:20, 246.46s/it]

Epoch 19 | Train Loss: 0.03836803052884837 | Validation Loss: 0.04097199055584187
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 40%|████      | 20/50 [1:21:50<2:01:43, 243.46s/it]

Epoch 20 | Train Loss: 0.03751841623646517 | Validation Loss: 0.04039845757305305
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 42%|████▏     | 21/50 [1:25:51<1:57:18, 242.70s/it]

Epoch 21 | Train Loss: 0.03681966278753761 | Validation Loss: 0.039685947186452276
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 44%|████▍     | 22/50 [1:29:49<1:52:39, 241.41s/it]

Epoch 22 | Train Loss: 0.03644095077438073 | Validation Loss: 0.03903538998256021
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 46%|████▌     | 23/50 [1:33:51<1:48:40, 241.51s/it]

Epoch 23 | Train Loss: 0.03579940949825363 | Validation Loss: 0.03860295956280637
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 48%|████▊     | 24/50 [1:37:46<1:43:48, 239.54s/it]

Epoch 24 | Train Loss: 0.035318089281726216 | Validation Loss: 0.038065280775202714
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 50%|█████     | 25/50 [1:41:44<1:39:37, 239.10s/it]

Epoch 25 | Train Loss: 0.03463039898003141 | Validation Loss: 0.03765453760849034
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 52%|█████▏    | 26/50 [1:45:35<1:34:44, 236.83s/it]

Epoch 26 | Train Loss: 0.03413296064698241 | Validation Loss: 0.03722696004316528
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 54%|█████▍    | 27/50 [1:49:27<1:30:13, 235.39s/it]

Epoch 27 | Train Loss: 0.03373740883544087 | Validation Loss: 0.03687221830354092
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 56%|█████▌    | 28/50 [1:53:26<1:26:40, 236.37s/it]

Epoch 28 | Train Loss: 0.03309145336323935 | Validation Loss: 0.03648334029229715
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 58%|█████▊    | 29/50 [1:57:25<1:22:58, 237.07s/it]

Epoch 29 | Train Loss: 0.032990467575533935 | Validation Loss: 0.036144390583355376
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 60%|██████    | 30/50 [2:01:22<1:19:02, 237.11s/it]

Epoch 30 | Train Loss: 0.03204299850234141 | Validation Loss: 0.03588741225131015
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 62%|██████▏   | 31/50 [2:05:19<1:15:04, 237.06s/it]

Epoch 31 | Train Loss: 0.03200841712661916 | Validation Loss: 0.035437613279816316
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 64%|██████▍   | 32/50 [2:09:22<1:11:40, 238.90s/it]

Epoch 32 | Train Loss: 0.03173059131044687 | Validation Loss: 0.0352154822247301
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 66%|██████▌   | 33/50 [2:13:27<1:08:11, 240.67s/it]

Epoch 33 | Train Loss: 0.03166051355397536 | Validation Loss: 0.03491893503814936
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 68%|██████▊   | 34/50 [2:17:24<1:03:53, 239.59s/it]

Epoch 34 | Train Loss: 0.030927557940594853 | Validation Loss: 0.034668735149217415
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 70%|███████   | 35/50 [2:21:23<59:49, 239.31s/it]  

Epoch 35 | Train Loss: 0.030635444699631382 | Validation Loss: 0.034437001851248615
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 72%|███████▏  | 36/50 [2:25:18<55:32, 238.07s/it]

Epoch 36 | Train Loss: 0.030346881296847844 | Validation Loss: 0.03412337486255676
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 74%|███████▍  | 37/50 [2:29:17<51:38, 238.36s/it]

Epoch 37 | Train Loss: 0.029961400375598007 | Validation Loss: 0.03394511160737974
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 76%|███████▌  | 38/50 [2:33:24<48:13, 241.11s/it]

Epoch 38 | Train Loss: 0.02957251178773327 | Validation Loss: 0.033688071847675326
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 78%|███████▊  | 39/50 [2:37:25<44:12, 241.12s/it]

Epoch 39 | Train Loss: 0.029442025480481485 | Validation Loss: 0.03357123950139639
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 80%|████████  | 40/50 [2:41:38<40:45, 244.55s/it]

Epoch 40 | Train Loss: 0.029077024807015225 | Validation Loss: 0.03329619012297468
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 82%|████████▏ | 41/50 [2:45:38<36:28, 243.14s/it]

Epoch 41 | Train Loss: 0.028573510120622814 | Validation Loss: 0.03311376530241142
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 84%|████████▍ | 42/50 [2:49:41<32:25, 243.17s/it]

Epoch 42 | Train Loss: 0.02844842243551587 | Validation Loss: 0.03292876949652712
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 86%|████████▌ | 43/50 [2:53:41<28:15, 242.20s/it]

Epoch 43 | Train Loss: 0.028345626299010798 | Validation Loss: 0.032745456877858084
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 88%|████████▊ | 44/50 [2:57:39<24:06, 241.05s/it]

Epoch 44 | Train Loss: 0.02806067076873862 | Validation Loss: 0.03258242082108367
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 90%|█████████ | 45/50 [3:01:45<20:11, 242.37s/it]

Epoch 45 | Train Loss: 0.027546941812357142 | Validation Loss: 0.032476947454616746
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 92%|█████████▏| 46/50 [3:05:49<16:11, 242.92s/it]

Epoch 46 | Train Loss: 0.027647043166992565 | Validation Loss: 0.032302697268413734
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 94%|█████████▍| 47/50 [3:09:47<12:04, 241.41s/it]

Epoch 47 | Train Loss: 0.027171155265791135 | Validation Loss: 0.03214357023503869
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 96%|█████████▌| 48/50 [3:13:45<08:00, 240.44s/it]

Epoch 48 | Train Loss: 0.027106327426412866 | Validation Loss: 0.0320096748405473
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 98%|█████████▊| 49/50 [3:17:43<03:59, 239.61s/it]

Epoch 49 | Train Loss: 0.02661641671632727 | Validation Loss: 0.03185276660394478
Batch: 1
Batch: 101
Batch: 201
Batch: 301


100%|██████████| 50/50 [3:21:52<00:00, 242.25s/it]

Epoch 50 | Train Loss: 0.026571826328937377 | Validation Loss: 0.03165640120216189





NameError: name 'encode_pair' is not defined

In [9]:

state_dict = torch.load("best_model_v2_lr1.pth")
model= ShroomBasHoGayaBhai(256)


model.load_state_dict(state_dict)



OrderedDict([('word_embedding_model.auto_model.embeddings.word_embeddings.weight', tensor([[-0.0166, -0.0666, -0.0163,  ..., -0.0200, -0.0514, -0.0264],
        [-0.0132, -0.0673, -0.0161,  ..., -0.0227, -0.0554, -0.0260],
        [-0.0176, -0.0709, -0.0144,  ..., -0.0246, -0.0596, -0.0232],
        ...,
        [-0.0231, -0.0588, -0.0105,  ..., -0.0195, -0.0262, -0.0212],
        [-0.0490, -0.0561, -0.0047,  ..., -0.0107, -0.0180, -0.0219],
        [-0.0065, -0.0915, -0.0025,  ..., -0.0151, -0.0504,  0.0460]])), ('word_embedding_model.auto_model.embeddings.position_embeddings.weight', tensor([[ 0.0180, -0.0245, -0.0364,  ...,  0.0004,  0.0003,  0.0153],
        [ 0.0079,  0.0021, -0.0181,  ...,  0.0294,  0.0300, -0.0047],
        [-0.0112, -0.0019, -0.0111,  ...,  0.0164,  0.0189, -0.0080],
        ...,
        [ 0.0174,  0.0035, -0.0096,  ...,  0.0030,  0.0004, -0.0269],
        [ 0.0217, -0.0060,  0.0147,  ..., -0.0056, -0.0126, -0.0281],
        [ 0.0026, -0.0233,  0.0055,  ...,  0

<All keys matched successfully>

In [14]:
sentence1,sentence2=encode_pair([
"A person who cobbles .","Nonsense ."
])
outputs=model(sentence1,sentence2)
print(outputs)

tensor([0.2301], grad_fn=<SumBackward1>)
