In [74]:
import torch.nn as nn
import torch.optim as optim
import torch.nn as nn
from sentence_transformers import models
from transformers import AutoTokenizer
from transformers import BertTokenizer, BertModel
import pandas as pd
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader

In [98]:
# Using an already pretrained SBERT to compute cosine similarities which would then be used 

df=pd.read_csv("stsb_train.csv")
df=df.drop("idx",axis=1)
df_cosine_labels = df.copy() 
df_cosine_labels["label"] = df_cosine_labels["label"] /df_cosine_labels["label"].abs().max() 


df=pd.read_csv("stsb_validation.csv")
df=df.drop("idx",axis=1)
validation_cosine_labels = df.copy() 
validation_cosine_labels["label"] = validation_cosine_labels["label"] /validation_cosine_labels["label"].abs().max() 



In [99]:
train_examples = []
for i,row in df_cosine_labels.iterrows():
    train_examples.append([[row["sentence1"], row["sentence2"]],row["label"]])

train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)

val_examples = []
for i,row in validation_cosine_labels.iterrows():
    val_examples.append([[row["sentence1"], row["sentence2"]],row["label"]])

val_dataloader = DataLoader(val_examples, shuffle=True, batch_size=16)

In [46]:

class Shroomformer(nn.Module):
    def __init__(self, sequence_length):
        super(Shroomformer, self).__init__()
        self.sequence_length = sequence_length
        self.word_embedding_model = models.Transformer("bert-base-uncased", max_seq_length=self.sequence_length)
        self.pooling_model = models.Pooling(self.word_embedding_model.get_word_embedding_dimension())
        self.twin1 = nn.Sequential(
            self.word_embedding_model,
            self.pooling_model
        )
        self.twin2 = nn.Sequential(
            self.word_embedding_model,
            self.pooling_model
        )
        self.cos = nn.CosineSimilarity(dim=1, eps=1e-6)
           
    def forward(self, sentence1_encoded, sentence2_encoded):
        output1 = self.twin1(sentence1_encoded)["sentence_embedding"]
        output2 = self.twin2(sentence2_encoded)["sentence_embedding"]
        similarity_score = self.cos(output1, output2)
        return similarity_score


def encode_pair(pair,padding=True,truncation=True,max_length=128):
    tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
    return tokenizer(
    pair[0], padding=padding,truncation=truncation,max_length=max_length, return_tensors="pt"
    ), tokenizer(
        pair[1],padding=padding,truncation=truncation,max_length=max_length, return_tensors="pt"
    )


model = Shroomformer(128)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
num_epochs=500

from tqdm import tqdm
for epoch in tqdm(range(num_epochs)):
    running_loss = 0.0
    correct = 0
    total = 0

    for i, data in enumerate(train_dataloader, 0):
        if i%100==0:
            print(f"Batch: {i+1}")
        sentences, labels = data
        optimizer.zero_grad()
        sentence1,sentence2=encode_pair(sentences)
        outputs = model(sentence1,sentence2)
        loss = criterion(outputs, labels.float())
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print()
    print(f'Epoch {epoch+1} | Loss: {running_loss / len(train_dataloader)}')




  0%|          | 0/500 [00:00<?, ?it/s]

Batch: 1
Batch: 101
Batch: 201
Batch: 301


  0%|          | 1/500 [06:04<50:28:23, 364.13s/it]


Epoch 1 | Loss: 0.2966899295647939
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  0%|          | 2/500 [11:48<48:46:37, 352.61s/it]


Epoch 2 | Loss: 0.2975533603794045
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  1%|          | 3/500 [17:37<48:26:52, 350.93s/it]


Epoch 3 | Loss: 0.2977091911352343
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  1%|          | 4/500 [23:04<47:03:52, 341.60s/it]


Epoch 4 | Loss: 0.29739760547462435
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  1%|          | 5/500 [28:40<46:39:41, 339.36s/it]


Epoch 5 | Loss: 0.2972705681497852
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  1%|          | 6/500 [34:35<47:17:45, 344.67s/it]


Epoch 6 | Loss: 0.29716408494859936
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  1%|▏         | 7/500 [40:26<47:29:54, 346.84s/it]


Epoch 7 | Loss: 0.2976262398478058
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  2%|▏         | 8/500 [46:15<47:30:25, 347.61s/it]


Epoch 8 | Loss: 0.2974280819710758
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  2%|▏         | 9/500 [52:02<47:23:09, 347.43s/it]


Epoch 9 | Loss: 0.2969092881307006
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  2%|▏         | 10/500 [57:49<47:14:15, 347.05s/it]


Epoch 10 | Loss: 0.2974702525056071
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  2%|▏         | 11/500 [1:03:40<47:18:23, 348.27s/it]


Epoch 11 | Loss: 0.29673791223516066
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  2%|▏         | 12/500 [1:09:30<47:18:27, 348.99s/it]


Epoch 12 | Loss: 0.2974832391573323
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  3%|▎         | 13/500 [1:15:19<47:12:11, 348.93s/it]


Epoch 13 | Loss: 0.29691902266608344
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  3%|▎         | 14/500 [1:21:09<47:08:17, 349.17s/it]


Epoch 14 | Loss: 0.2970412415348821
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  3%|▎         | 15/500 [1:26:54<46:52:05, 347.89s/it]


Epoch 15 | Loss: 0.29676136624895866
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  3%|▎         | 16/500 [1:32:35<46:30:25, 345.92s/it]


Epoch 16 | Loss: 0.2971311234351661
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  3%|▎         | 17/500 [1:37:56<45:24:25, 338.44s/it]


Epoch 17 | Loss: 0.2970363655231065
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  4%|▎         | 18/500 [1:43:18<44:38:06, 333.37s/it]


Epoch 18 | Loss: 0.2973413092808591
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  4%|▍         | 19/500 [1:48:37<43:58:45, 329.16s/it]


Epoch 19 | Loss: 0.2971671583958798
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  4%|▍         | 20/500 [1:54:02<43:43:19, 327.91s/it]


Epoch 20 | Loss: 0.2974299171732532
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  4%|▍         | 21/500 [1:59:30<43:38:33, 328.00s/it]


Epoch 21 | Loss: 0.2970357896553146
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  4%|▍         | 22/500 [2:04:55<43:25:22, 327.03s/it]


Epoch 22 | Loss: 0.2976896764089664
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  5%|▍         | 23/500 [2:10:27<43:31:37, 328.51s/it]


Epoch 23 | Loss: 0.29718892876472736
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  5%|▍         | 24/500 [2:15:58<43:31:14, 329.15s/it]


Epoch 24 | Loss: 0.2974106328354941
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  5%|▌         | 25/500 [2:21:31<43:35:34, 330.39s/it]


Epoch 25 | Loss: 0.2975405416968796
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  5%|▌         | 26/500 [2:26:46<42:52:49, 325.67s/it]


Epoch 26 | Loss: 0.2973470267529289
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  5%|▌         | 27/500 [2:32:02<42:26:40, 323.05s/it]


Epoch 27 | Loss: 0.29675162564963103
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  6%|▌         | 28/500 [2:37:23<42:14:29, 322.18s/it]


Epoch 28 | Loss: 0.2971121369757586
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  6%|▌         | 29/500 [2:42:40<41:57:20, 320.68s/it]


Epoch 29 | Loss: 0.29738076244377426
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  6%|▌         | 30/500 [2:47:59<41:49:11, 320.32s/it]


Epoch 30 | Loss: 0.29714426654908394
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  6%|▌         | 31/500 [2:53:19<41:42:35, 320.16s/it]


Epoch 31 | Loss: 0.29718363787978885
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  6%|▋         | 32/500 [2:58:39<41:37:46, 320.23s/it]


Epoch 32 | Loss: 0.29710454253686797
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  7%|▋         | 33/500 [3:04:05<41:44:51, 321.82s/it]


Epoch 33 | Loss: 0.2969538603598873
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  7%|▋         | 34/500 [3:09:29<41:44:48, 322.51s/it]


Epoch 34 | Loss: 0.2974938226449821
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  7%|▋         | 35/500 [3:14:57<41:52:34, 324.20s/it]


Epoch 35 | Loss: 0.29713720358494256
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  7%|▋         | 36/500 [3:20:23<41:50:25, 324.62s/it]


Epoch 36 | Loss: 0.2972820875959264
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  7%|▋         | 37/500 [3:25:47<41:44:14, 324.52s/it]


Epoch 37 | Loss: 0.29718119137816956
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  8%|▊         | 38/500 [3:31:19<41:55:56, 326.75s/it]


Epoch 38 | Loss: 0.29689214101268185
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  8%|▊         | 39/500 [3:36:45<41:48:21, 326.47s/it]


Epoch 39 | Loss: 0.29693003474838203
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  8%|▊         | 40/500 [3:42:14<41:48:53, 327.25s/it]


Epoch 40 | Loss: 0.2967686936880151
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  8%|▊         | 41/500 [3:47:49<42:01:08, 329.56s/it]


Epoch 41 | Loss: 0.2971170264813635
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  8%|▊         | 42/500 [3:53:43<42:52:40, 337.03s/it]


Epoch 42 | Loss: 0.2969654782778687
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  9%|▊         | 43/500 [3:59:29<43:07:36, 339.73s/it]


Epoch 43 | Loss: 0.29787480580723946
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  9%|▉         | 44/500 [4:05:14<43:14:02, 341.32s/it]


Epoch 44 | Loss: 0.29748491723504333
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  9%|▉         | 45/500 [4:11:00<43:17:32, 342.53s/it]


Epoch 45 | Loss: 0.2970293187225858
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  9%|▉         | 46/500 [4:16:46<43:20:33, 343.69s/it]


Epoch 46 | Loss: 0.29748185980651115
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  9%|▉         | 47/500 [4:22:14<42:39:30, 339.01s/it]


Epoch 47 | Loss: 0.29746189024299385
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 10%|▉         | 48/500 [4:27:34<41:49:57, 333.18s/it]


Epoch 48 | Loss: 0.2969899263025986
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 10%|▉         | 49/500 [4:32:54<41:14:12, 329.16s/it]


Epoch 49 | Loss: 0.29748003056479827
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 10%|█         | 50/500 [4:38:10<40:39:38, 325.29s/it]


Epoch 50 | Loss: 0.29711031826833884
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 10%|█         | 51/500 [4:43:29<40:20:08, 323.40s/it]


Epoch 51 | Loss: 0.2974265945661399
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 10%|█         | 52/500 [4:48:55<40:20:12, 324.14s/it]


Epoch 52 | Loss: 0.29701069467183616
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 11%|█         | 53/500 [4:54:15<40:07:12, 323.11s/it]


Epoch 53 | Loss: 0.29725976379381286
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 11%|█         | 54/500 [4:59:36<39:55:38, 322.28s/it]


Epoch 54 | Loss: 0.29717814293172623
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 11%|█         | 55/500 [5:04:51<39:34:27, 320.15s/it]


Epoch 55 | Loss: 0.29716652360641294
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 11%|█         | 56/500 [5:10:08<39:21:13, 319.08s/it]


Epoch 56 | Loss: 0.2973021920770407
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 11%|█▏        | 57/500 [5:15:24<39:10:47, 318.39s/it]


Epoch 57 | Loss: 0.2971580664730734
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 12%|█▏        | 58/500 [5:20:46<39:11:56, 319.27s/it]


Epoch 58 | Loss: 0.29714636359777713
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 12%|█▏        | 59/500 [5:26:03<39:02:35, 318.72s/it]


Epoch 59 | Loss: 0.296834854553971
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 12%|█▏        | 60/500 [5:31:22<38:57:53, 318.80s/it]


Epoch 60 | Loss: 0.2974983539018366
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 12%|█▏        | 61/500 [5:36:45<39:01:02, 319.96s/it]


Epoch 61 | Loss: 0.2975393009475536
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 12%|█▏        | 62/500 [5:42:05<38:55:53, 319.98s/it]


Epoch 62 | Loss: 0.2968762532497446
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 13%|█▎        | 63/500 [5:47:24<38:49:26, 319.83s/it]


Epoch 63 | Loss: 0.29699725330703786
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 13%|█▎        | 64/500 [5:52:44<38:42:46, 319.65s/it]


Epoch 64 | Loss: 0.29715224814911684
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 13%|█▎        | 65/500 [5:57:56<38:22:15, 317.55s/it]


Epoch 65 | Loss: 0.29711214401241803
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 13%|█▎        | 66/500 [6:03:08<38:04:08, 315.78s/it]


Epoch 66 | Loss: 0.2970290288536085
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 13%|█▎        | 67/500 [6:08:21<37:54:12, 315.13s/it]


Epoch 67 | Loss: 0.29748430608047377
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 14%|█▎        | 68/500 [6:13:35<37:46:35, 314.80s/it]


Epoch 68 | Loss: 0.2969165914174583
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 14%|█▍        | 69/500 [6:18:52<37:45:58, 315.45s/it]


Epoch 69 | Loss: 0.29782836126784484
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 14%|█▍        | 70/500 [6:24:09<37:42:31, 315.70s/it]


Epoch 70 | Loss: 0.29695387077000407
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 14%|█▍        | 71/500 [6:29:42<38:15:41, 321.08s/it]


Epoch 71 | Loss: 0.29725513915634816
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 14%|█▍        | 72/500 [6:35:36<39:20:52, 330.96s/it]


Epoch 72 | Loss: 0.29710784831808673
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 15%|█▍        | 73/500 [6:41:05<39:09:38, 330.16s/it]


Epoch 73 | Loss: 0.2973006478200356
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 15%|█▍        | 74/500 [6:46:53<39:42:37, 335.58s/it]


Epoch 74 | Loss: 0.2971371962585383
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 15%|█▌        | 75/500 [6:52:38<39:57:06, 338.42s/it]


Epoch 75 | Loss: 0.296810245865749
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 15%|█▌        | 76/500 [6:58:02<39:20:15, 334.00s/it]


Epoch 76 | Loss: 0.2969605765822861
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 15%|█▌        | 77/500 [7:03:30<39:03:42, 332.44s/it]


Epoch 77 | Loss: 0.2968646371116241
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 16%|█▌        | 78/500 [7:08:51<38:32:18, 328.76s/it]


Epoch 78 | Loss: 0.29704369532151353
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 16%|█▌        | 79/500 [7:14:25<38:39:14, 330.53s/it]


Epoch 79 | Loss: 0.2973608639505174
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 16%|█▌        | 80/500 [7:19:43<38:07:40, 326.81s/it]


Epoch 80 | Loss: 0.29749229177832603
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 16%|█▌        | 81/500 [7:25:09<38:00:13, 326.52s/it]


Epoch 81 | Loss: 0.29712497947944533
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 16%|█▋        | 82/500 [7:30:52<38:28:45, 331.40s/it]


Epoch 82 | Loss: 0.2971757040669521
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 17%|█▋        | 83/500 [7:36:23<38:21:17, 331.12s/it]


Epoch 83 | Loss: 0.29771653250273733
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 17%|█▋        | 84/500 [7:42:05<38:39:17, 334.51s/it]


Epoch 84 | Loss: 0.2972221368095941
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 17%|█▋        | 85/500 [7:47:45<38:45:39, 336.24s/it]


Epoch 85 | Loss: 0.29715370051562784
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 17%|█▋        | 86/500 [7:53:27<38:51:34, 337.91s/it]


Epoch 86 | Loss: 0.29705836671508024
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 17%|█▋        | 87/500 [7:58:53<38:20:56, 334.28s/it]


Epoch 87 | Loss: 0.29713726883961095
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 18%|█▊        | 88/500 [8:04:26<38:12:41, 333.89s/it]


Epoch 88 | Loss: 0.2970279625099566
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 18%|█▊        | 89/500 [8:09:52<37:50:39, 331.48s/it]


Epoch 89 | Loss: 0.29722764500313337
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 18%|█▊        | 90/500 [8:15:17<37:32:42, 329.66s/it]


Epoch 90 | Loss: 0.29716714463300176
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 18%|█▊        | 91/500 [8:20:46<37:25:18, 329.39s/it]


Epoch 91 | Loss: 0.29705916552080047
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 18%|█▊        | 92/500 [8:26:11<37:10:50, 328.06s/it]


Epoch 92 | Loss: 0.2972056430040134
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 19%|█▊        | 93/500 [8:31:34<36:55:29, 326.61s/it]


Epoch 93 | Loss: 0.29717295792781645
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 19%|█▉        | 94/500 [8:36:54<36:35:45, 324.50s/it]


Epoch 94 | Loss: 0.297664586123493
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 19%|█▉        | 95/500 [8:42:25<36:43:45, 326.48s/it]


Epoch 95 | Loss: 0.29705347634024093
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 19%|█▉        | 96/500 [8:47:59<36:53:12, 328.69s/it]


Epoch 96 | Loss: 0.29706396922055217
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 19%|█▉        | 97/500 [8:53:18<36:28:41, 325.86s/it]


Epoch 97 | Loss: 0.297199876109759
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 20%|█▉        | 98/500 [8:58:40<36:16:02, 324.78s/it]


Epoch 98 | Loss: 0.2968664745282796
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 20%|█▉        | 99/500 [9:04:01<36:03:39, 323.74s/it]


Epoch 99 | Loss: 0.2968729304977589
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 20%|██        | 100/500 [9:09:21<35:49:59, 322.50s/it]


Epoch 100 | Loss: 0.29754847710331284
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 20%|██        | 101/500 [9:14:38<35:33:02, 320.76s/it]


Epoch 101 | Loss: 0.29719463781350186
Batch: 1


 20%|██        | 101/500 [9:15:00<36:32:33, 329.71s/it]


KeyboardInterrupt: 

In [47]:
torch.save(model.state_dict(), 'model_v1.pth')

In [84]:

class ShroomformerV2(nn.Module):
    def __init__(self):
        super(ShroomformerV2, self).__init__()
        self.twin1 = BertModel.from_pretrained("bert-base-uncased")
        self.twin2 = BertModel.from_pretrained("bert-base-uncased")
        self.cos = nn.CosineSimilarity(dim=1, eps=1e-6)
    
    def mean_pooling(self,word_embeddings, attention_mask):
        token_embeddings = word_embeddings[0] 
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask

  
    def forward(self, sentence1_encoded, sentence2_encoded):
        word_embeddings_1 = self.twin1(**sentence1_encoded)
        sentence_embeddings_1=self.mean_pooling(word_embeddings_1, sentence1_encoded["attention_mask"])
        word_embeddings_2 = self.twin2(**sentence2_encoded)
        sentence_embeddings_2=self.mean_pooling(word_embeddings_2, sentence2_encoded["attention_mask"])
        similarity_score = self.cos(sentence_embeddings_1,sentence_embeddings_2)
        return similarity_score
       


def encode_pair(pair,padding=True,truncation=True,max_length=128):
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    return tokenizer(
    pair[0], padding=padding,truncation=truncation,max_length=max_length, return_tensors="pt"
    ), tokenizer(
        pair[1],padding=padding,truncation=truncation,max_length=max_length, return_tensors="pt"
    )


model = ShroomformerV2()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
num_epochs=100
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Total parameters in the model:", total_params)
model.train()
from tqdm import tqdm
for epoch in tqdm(range(num_epochs)):
    running_loss = 0.0
    correct = 0
    total = 0

    for i, data in enumerate(train_dataloader, 0):
        if i%100==0:
            print(f"Batch: {i+1}")
        sentences, labels = data
        optimizer.zero_grad()
        sentence1,sentence2=encode_pair(sentences)
        outputs=model(sentence1,sentence2)
        loss = criterion(outputs, labels.float())
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print()
    print(f'Epoch {epoch+1} | Loss: {running_loss / len(train_dataloader)}')




Total parameters in the model: 218964480


  0%|          | 0/100 [00:00<?, ?it/s]

Batch: 1
Batch: 101
Batch: 201
Batch: 301


  1%|          | 1/100 [06:58<11:30:45, 418.64s/it]


Epoch 1 | Loss: 0.09140385061295496
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  2%|▏         | 2/100 [13:58<11:24:38, 419.16s/it]


Epoch 2 | Loss: 0.08696048661756019
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  3%|▎         | 3/100 [20:59<11:19:13, 420.14s/it]


Epoch 3 | Loss: 0.0866039344161335
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  4%|▍         | 4/100 [28:57<11:48:30, 442.82s/it]


Epoch 4 | Loss: 0.08692328405773474
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  5%|▌         | 5/100 [36:34<11:49:23, 448.04s/it]


Epoch 5 | Loss: 0.08699724107152886
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  6%|▌         | 6/100 [45:12<12:19:26, 471.99s/it]


Epoch 6 | Loss: 0.08647807241520948
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  7%|▋         | 7/100 [52:44<12:01:19, 465.37s/it]


Epoch 7 | Loss: 0.08636120182151595
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  8%|▊         | 8/100 [1:00:10<11:44:08, 459.23s/it]


Epoch 8 | Loss: 0.08646756580306424
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  9%|▉         | 9/100 [1:07:48<11:35:39, 458.68s/it]


Epoch 9 | Loss: 0.08627520025604302
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 10%|█         | 10/100 [1:15:10<11:20:28, 453.65s/it]


Epoch 10 | Loss: 0.08648439190453953
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 11%|█         | 11/100 [1:22:24<11:04:06, 447.71s/it]


Epoch 11 | Loss: 0.08641103722362055
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 12%|█▏        | 12/100 [1:29:52<10:56:49, 447.84s/it]


Epoch 12 | Loss: 0.0863054936648243
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 13%|█▎        | 13/100 [1:37:05<10:42:39, 443.21s/it]


Epoch 13 | Loss: 0.08615879628600345
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 14%|█▍        | 14/100 [1:44:12<10:28:21, 438.39s/it]


Epoch 14 | Loss: 0.08623913600006038
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 15%|█▌        | 15/100 [1:51:25<10:18:45, 436.77s/it]


Epoch 15 | Loss: 0.08649165533586509
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 16%|█▌        | 16/100 [1:58:35<10:08:41, 434.78s/it]


Epoch 16 | Loss: 0.08620000312932663
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 17%|█▋        | 17/100 [2:05:51<10:01:44, 434.99s/it]


Epoch 17 | Loss: 0.08592957039881084
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 18%|█▊        | 18/100 [2:13:26<10:02:33, 440.90s/it]


Epoch 18 | Loss: 0.08600735939107836
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 19%|█▉        | 19/100 [2:20:55<9:58:53, 443.62s/it] 


Epoch 19 | Loss: 0.08641308066952559
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 20%|██        | 20/100 [2:28:08<9:47:06, 440.33s/it]


Epoch 20 | Loss: 0.08604305102489888
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 21%|██        | 21/100 [2:35:32<9:41:08, 441.38s/it]


Epoch 21 | Loss: 0.08627502795101868
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 22%|██▏       | 22/100 [2:43:01<9:36:48, 443.70s/it]


Epoch 22 | Loss: 0.08606046617238058
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 23%|██▎       | 23/100 [2:50:20<9:27:36, 442.29s/it]


Epoch 23 | Loss: 0.08600626779306265
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 24%|██▍       | 24/100 [2:57:33<9:16:51, 439.63s/it]


Epoch 24 | Loss: 0.08608038294025594
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 25%|██▌       | 25/100 [3:04:47<9:07:06, 437.69s/it]


Epoch 25 | Loss: 0.08594423026467363
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 26%|██▌       | 26/100 [3:12:06<9:00:20, 438.11s/it]


Epoch 26 | Loss: 0.08601857399464481
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 27%|██▋       | 27/100 [3:19:23<8:52:37, 437.77s/it]


Epoch 27 | Loss: 0.08579579836999376
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 28%|██▊       | 28/100 [3:27:41<9:07:15, 456.04s/it]


Epoch 28 | Loss: 0.08601455138996243
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 29%|██▉       | 29/100 [3:34:51<8:50:15, 448.10s/it]


Epoch 29 | Loss: 0.08579812913926112
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 30%|███       | 30/100 [3:41:19<8:21:35, 429.94s/it]


Epoch 30 | Loss: 0.08584257872878677
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 31%|███       | 31/100 [3:54:47<10:25:03, 543.53s/it]


Epoch 31 | Loss: 0.08598652690028151
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 32%|███▏      | 32/100 [4:04:49<10:35:47, 560.99s/it]


Epoch 32 | Loss: 0.08598137813516789
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 33%|███▎      | 33/100 [4:11:21<9:29:44, 510.22s/it] 


Epoch 33 | Loss: 0.0858574364748266
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 34%|███▍      | 34/100 [4:17:54<8:42:42, 475.19s/it]


Epoch 34 | Loss: 0.08593545713358455
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 35%|███▌      | 35/100 [4:25:02<8:19:17, 460.88s/it]


Epoch 35 | Loss: 0.08628203681566649
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 36%|███▌      | 36/100 [4:32:09<8:00:59, 450.93s/it]


Epoch 36 | Loss: 0.085855706812193
Batch: 1
Batch: 101
Batch: 201


In [15]:
# Check that MPS is available
if not torch.backends.mps.is_available():
    if not torch.backends.mps.is_built():
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    else:
        print("MPS not available because the current MacOS version is not 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")
    mps_device = torch.device("cpu")
    

else:
    mps_device = torch.device("mps")

In [91]:

class ShroomformerV2(nn.Module):
    def __init__(self):
        super(ShroomformerV2, self).__init__()
        self.twin1 = BertModel.from_pretrained("bert-base-uncased")
        self.twin2 = BertModel.from_pretrained("bert-base-uncased")
        self.cos = nn.CosineSimilarity(dim=1, eps=1e-6)
    
    def mean_pooling(self,word_embeddings, attention_mask):
        token_embeddings = word_embeddings[0] 
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask

  
    def forward(self, sentence1_encoded, sentence2_encoded):
        word_embeddings_1 = self.twin1(**sentence1_encoded)
        sentence_embeddings_1=self.mean_pooling(word_embeddings_1, sentence1_encoded["attention_mask"])
        word_embeddings_2 = self.twin2(**sentence2_encoded)
        sentence_embeddings_2=self.mean_pooling(word_embeddings_2, sentence2_encoded["attention_mask"])
        similarity_score = self.cos(sentence_embeddings_1,sentence_embeddings_2)
        return similarity_score
       


def encode_pair(pair, padding=True, truncation=True, max_length=128):
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    encoded_pair_1 = tokenizer(
        pair[0], 
        padding=padding, 
        truncation=truncation, 
        max_length=max_length, 
        return_tensors="pt"
    )
    encoded_pair_2 = tokenizer(
        pair[1],
        padding=padding,
        truncation=truncation,
        max_length=max_length,
        return_tensors="pt"
    )
    

    return encoded_pair_1, encoded_pair_2
model = ShroomformerV2()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.00001)
num_epochs=100
best_val_loss = float('inf')
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Total parameters in the model:", total_params)
for epoch in tqdm(range(num_epochs)):
    running_loss = 0.0
    correct = 0
    total = 0
    model.train()
    for i, data in enumerate(train_dataloader, 0):
        if i%100==0:
            print(f"Batch: {i+1}")
        sentences, labels = data
        optimizer.zero_grad()
        sentence1,sentence2=encode_pair(sentences)
        outputs=model(sentence1,sentence2)
        loss = criterion(outputs, labels.float())
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for val_data in val_dataloader:
            val_sentences, val_labels = val_data
            val_sentence1, val_sentence2 = encode_pair(val_sentences)
            val_outputs = model(val_sentence1, val_sentence2)
            val_loss += criterion(val_outputs, val_labels.float()).item()
    val_loss /= len(val_dataloader)

    print(f'Epoch {epoch+1} | Train Loss: {running_loss / len(train_dataloader)} | Validation Loss: {val_loss}')

    if val_loss < best_val_loss:
        torch.save(model.state_dict(), "best_model_v2_lr1.pth")
        best_val_loss = val_loss


Total parameters in the model: 218964480


  0%|          | 0/100 [00:00<?, ?it/s]

Batch: 1
Batch: 101
Batch: 201
Batch: 301
Epoch 1 | Train Loss: 0.04102387403448423 | Validation Loss: 0.03369533416240456


  1%|          | 1/100 [08:55<14:43:51, 535.67s/it]

Batch: 1
Batch: 101
Batch: 201
Batch: 301
Epoch 2 | Train Loss: 0.020565950364754018 | Validation Loss: 0.029373987875086195


  2%|▏         | 2/100 [17:17<14:02:35, 515.87s/it]

Batch: 1
Batch: 101
Batch: 201
Batch: 301
Epoch 3 | Train Loss: 0.012345999609290932 | Validation Loss: 0.029069310965690206


  3%|▎         | 3/100 [24:57<13:12:57, 490.49s/it]

Batch: 1
Batch: 101
Batch: 201
Batch: 301


  4%|▍         | 4/100 [32:53<12:55:04, 484.42s/it]

Epoch 4 | Train Loss: 0.007954893568724704 | Validation Loss: 0.029124279759150554
Batch: 1
Batch: 101
Batch: 201
Batch: 301
Epoch 5 | Train Loss: 0.005968775502535411 | Validation Loss: 0.028309783888386286


  5%|▌         | 5/100 [42:27<13:38:14, 516.79s/it]

Batch: 1
Batch: 101
Batch: 201
Batch: 301


  6%|▌         | 6/100 [51:07<13:31:43, 518.12s/it]

Epoch 6 | Train Loss: 0.004712820964697231 | Validation Loss: 0.028533446434092648
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  7%|▋         | 7/100 [1:00:14<13:37:43, 527.56s/it]

Epoch 7 | Train Loss: 0.004191748624564045 | Validation Loss: 0.029183450382836956
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  8%|▊         | 8/100 [1:09:13<13:34:28, 531.18s/it]

Epoch 8 | Train Loss: 0.003667591901871169 | Validation Loss: 0.028948412841542603
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  9%|▉         | 9/100 [1:17:58<13:22:36, 529.19s/it]

Epoch 9 | Train Loss: 0.003445363009490797 | Validation Loss: 0.0295445860680589
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 10%|█         | 10/100 [1:26:42<13:11:16, 527.52s/it]

Epoch 10 | Train Loss: 0.0033531822940050106 | Validation Loss: 0.029611960846058865
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 11%|█         | 11/100 [1:35:39<13:06:45, 530.40s/it]

Epoch 11 | Train Loss: 0.0032314999809993123 | Validation Loss: 0.030233834315329156
Batch: 1


 11%|█         | 11/100 [1:36:46<13:02:59, 527.86s/it]


KeyboardInterrupt: 

In [92]:
torch.save(model.state_dict(), "model_v2_lr1.pth")
lrs={"lr1":0.00001}

In [93]:
# Check that MPS is available
if not torch.backends.mps.is_available():
    if not torch.backends.mps.is_built():
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    else:
        print("MPS not available because the current MacOS version is not 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")
    device = torch.device("cpu")
    

else:
    device = torch.device("mps")

In [101]:

from transformers import DistilBertTokenizer, DistilBertModel
class ShroomformerV2(nn.Module):
    def __init__(self):
        super(ShroomformerV2, self).__init__()
        self.twin1 = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.twin2 = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.cos = nn.CosineSimilarity(dim=1, eps=1e-6)
    
    def mean_pooling(self,word_embeddings, attention_mask):
        token_embeddings = word_embeddings[0] 
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask

  
    def forward(self, sentence1_encoded, sentence2_encoded):
        word_embeddings_1 = self.twin1(**sentence1_encoded)
        sentence_embeddings_1=self.mean_pooling(word_embeddings_1, sentence1_encoded["attention_mask"])
        word_embeddings_2 = self.twin2(**sentence2_encoded)
        sentence_embeddings_2=self.mean_pooling(word_embeddings_2, sentence2_encoded["attention_mask"])
        similarity_score = self.cos(sentence_embeddings_1,sentence_embeddings_2)
        return similarity_score
       


def encode_pair(pair, padding=True, truncation=True, max_length=128):
    tokenizer=DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    encoded_pair_1 = tokenizer(
        pair[0], 
        padding=padding, 
        truncation=truncation, 
        max_length=max_length, 
        return_tensors="pt"
    )
    encoded_pair_2 = tokenizer(
        pair[1],
        padding=padding,
        truncation=truncation,
        max_length=max_length,
        return_tensors="pt"
    )
    return encoded_pair_1, encoded_pair_2
model = ShroomformerV2()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.00001)
num_epochs=100
best_val_loss = float('inf')
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Total parameters in the model:", total_params)
for epoch in tqdm(range(num_epochs)):
    running_loss = 0.0
    correct = 0
    total = 0
    model.train()
    for i, data in enumerate(train_dataloader, 0):
        if i%100==0:
            print(f"Batch: {i+1}")
        sentences, labels = data
        optimizer.zero_grad()
        sentence1,sentence2=encode_pair(sentences)
        labels = labels.float()
        outputs=model(sentence1,sentence2)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for val_data in val_dataloader:
            val_sentences, val_labels = val_data
            val_sentence1, val_sentence2 = encode_pair(val_sentences)
            val_labels = val_labels.float()
            val_outputs = model(val_sentence1, val_sentence2)
            val_loss += criterion(val_outputs, val_labels).item()
    val_loss /= len(val_dataloader)

    print(f'Epoch {epoch+1} | Train Loss: {running_loss / len(train_dataloader)} | Validation Loss: {val_loss}')

    if val_loss < best_val_loss:
        torch.save(model.state_dict(), "best_model_v2_lr1.pth")
        best_val_loss = val_loss


Total parameters in the model: 132725760


  0%|          | 0/100 [00:00<?, ?it/s]

Batch: 1
Batch: 101
Batch: 201
Batch: 301
Epoch 1 | Train Loss: 0.03836502712996056 | Validation Loss: 0.03810320022773552


  1%|          | 1/100 [04:25<7:18:41, 265.87s/it]

Batch: 1
Batch: 101
Batch: 201
Batch: 301
Epoch 2 | Train Loss: 0.017532667429703806 | Validation Loss: 0.03760222710193472


  2%|▏         | 2/100 [08:47<7:09:43, 263.10s/it]

Batch: 1
Batch: 101
Batch: 201
Batch: 301


  3%|▎         | 3/100 [13:09<7:05:01, 262.90s/it]

Epoch 3 | Train Loss: 0.010983124284151321 | Validation Loss: 0.04393888791983432
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  4%|▍         | 4/100 [17:27<6:57:39, 261.04s/it]

Epoch 4 | Train Loss: 0.009047266460321326 | Validation Loss: 0.042031212016306024
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  5%|▌         | 5/100 [21:45<6:51:23, 259.82s/it]

Epoch 5 | Train Loss: 0.008145773812106603 | Validation Loss: 0.04470572684039461
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  6%|▌         | 6/100 [26:07<6:48:03, 260.46s/it]

Epoch 6 | Train Loss: 0.007361653863012583 | Validation Loss: 0.0528841051966586
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  7%|▋         | 7/100 [30:25<6:42:28, 259.66s/it]

Epoch 7 | Train Loss: 0.0069409509228232 | Validation Loss: 0.04483182195256999
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  8%|▊         | 8/100 [34:45<6:38:25, 259.84s/it]

Epoch 8 | Train Loss: 0.0066320556791551 | Validation Loss: 0.049901716985759584
Batch: 1
Batch: 101
Batch: 201
Batch: 301


  9%|▉         | 9/100 [39:06<6:34:35, 260.17s/it]

Epoch 9 | Train Loss: 0.006080855336040259 | Validation Loss: 0.05398279103509923
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 10%|█         | 10/100 [43:20<6:27:29, 258.33s/it]

Epoch 10 | Train Loss: 0.0058590529172862366 | Validation Loss: 0.05486795157590445
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 11%|█         | 11/100 [47:32<6:20:16, 256.37s/it]

Epoch 11 | Train Loss: 0.005583721027101597 | Validation Loss: 0.05645773125852042
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 12%|█▏        | 12/100 [51:49<6:16:09, 256.47s/it]

Epoch 12 | Train Loss: 0.005252542689170999 | Validation Loss: 0.052621385558171474
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 13%|█▎        | 13/100 [56:02<6:10:22, 255.43s/it]

Epoch 13 | Train Loss: 0.005243390430890334 | Validation Loss: 0.059135470083577835
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 14%|█▍        | 14/100 [1:00:18<6:06:25, 255.64s/it]

Epoch 14 | Train Loss: 0.0051148370144397225 | Validation Loss: 0.05750164209290388
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 15%|█▌        | 15/100 [1:04:35<6:02:56, 256.20s/it]

Epoch 15 | Train Loss: 0.004661339668544113 | Validation Loss: 0.06011492178398878
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 16%|█▌        | 16/100 [1:08:51<5:58:36, 256.15s/it]

Epoch 16 | Train Loss: 0.00447752101835148 | Validation Loss: 0.05677363285398547
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 17%|█▋        | 17/100 [1:13:50<6:11:57, 268.89s/it]

Epoch 17 | Train Loss: 0.004361740897002165 | Validation Loss: 0.061271820455155473
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 18%|█▊        | 18/100 [1:18:07<6:02:44, 265.42s/it]

Epoch 18 | Train Loss: 0.004193833286828723 | Validation Loss: 0.05934846524069918
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 19%|█▉        | 19/100 [1:22:23<5:54:22, 262.50s/it]

Epoch 19 | Train Loss: 0.0040157022171317495 | Validation Loss: 0.061048531449063026
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 20%|██        | 20/100 [1:26:41<5:48:15, 261.20s/it]

Epoch 20 | Train Loss: 0.0038979770214710796 | Validation Loss: 0.061948282426183526
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 21%|██        | 21/100 [1:30:58<5:42:04, 259.80s/it]

Epoch 21 | Train Loss: 0.003686736776499957 | Validation Loss: 0.05856614011002982
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 22%|██▏       | 22/100 [1:35:14<5:36:32, 258.87s/it]

Epoch 22 | Train Loss: 0.0034561421188603466 | Validation Loss: 0.06055234640123362
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 23%|██▎       | 23/100 [1:39:29<5:30:25, 257.48s/it]

Epoch 23 | Train Loss: 0.0033164349509105604 | Validation Loss: 0.06157867552989975
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 24%|██▍       | 24/100 [1:43:39<5:23:30, 255.40s/it]

Epoch 24 | Train Loss: 0.0032505401150653293 | Validation Loss: 0.06337541022754097
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 25%|██▌       | 25/100 [1:47:54<5:19:08, 255.31s/it]

Epoch 25 | Train Loss: 0.003169686512006188 | Validation Loss: 0.06258038904993458
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 26%|██▌       | 26/100 [1:52:09<5:14:36, 255.09s/it]

Epoch 26 | Train Loss: 0.0031056049362329454 | Validation Loss: 0.06190154589514466
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 27%|██▋       | 27/100 [1:56:19<5:08:38, 253.67s/it]

Epoch 27 | Train Loss: 0.0029776057112030686 | Validation Loss: 0.06528336844070161
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 28%|██▊       | 28/100 [2:00:36<5:05:22, 254.49s/it]

Epoch 28 | Train Loss: 0.002881760100555968 | Validation Loss: 0.056083454730662896
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 29%|██▉       | 29/100 [2:04:51<5:01:22, 254.68s/it]

Epoch 29 | Train Loss: 0.0027602439592657093 | Validation Loss: 0.057898518599649056
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 30%|███       | 30/100 [2:09:03<4:56:11, 253.88s/it]

Epoch 30 | Train Loss: 0.0027361942924067585 | Validation Loss: 0.064347780547402
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 31%|███       | 31/100 [2:15:48<5:44:06, 299.22s/it]

Epoch 31 | Train Loss: 0.002730362086569787 | Validation Loss: 0.06158929173537391
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 32%|███▏      | 32/100 [2:21:39<5:56:49, 314.85s/it]

Epoch 32 | Train Loss: 0.002592412154869332 | Validation Loss: 0.06455616224953468
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 33%|███▎      | 33/100 [2:26:12<5:37:32, 302.27s/it]

Epoch 33 | Train Loss: 0.002475785572217622 | Validation Loss: 0.060219374148452534
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 34%|███▍      | 34/100 [2:30:41<5:21:40, 292.43s/it]

Epoch 34 | Train Loss: 0.002378760879380732 | Validation Loss: 0.0702144230537592
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 35%|███▌      | 35/100 [2:35:12<5:09:45, 285.93s/it]

Epoch 35 | Train Loss: 0.002470096041603635 | Validation Loss: 0.0632582888641256
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 36%|███▌      | 36/100 [2:39:48<5:01:35, 282.74s/it]

Epoch 36 | Train Loss: 0.002263089244661387 | Validation Loss: 0.06471711920296892
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 37%|███▋      | 37/100 [2:44:21<4:53:54, 279.91s/it]

Epoch 37 | Train Loss: 0.0022074550032104728 | Validation Loss: 0.06814944914522324
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 38%|███▊      | 38/100 [2:48:59<4:48:42, 279.39s/it]

Epoch 38 | Train Loss: 0.0021116287646388327 | Validation Loss: 0.06632244854452128
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 39%|███▉      | 39/100 [2:53:38<4:44:02, 279.39s/it]

Epoch 39 | Train Loss: 0.002085785892753241 | Validation Loss: 0.06388280168175697
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 40%|████      | 40/100 [2:58:19<4:39:40, 279.67s/it]

Epoch 40 | Train Loss: 0.0020659903769329603 | Validation Loss: 0.06465654827139163
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 41%|████      | 41/100 [3:03:05<4:36:49, 281.52s/it]

Epoch 41 | Train Loss: 0.0019746832202751346 | Validation Loss: 0.06656675016943445
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 42%|████▏     | 42/100 [3:07:43<4:31:11, 280.54s/it]

Epoch 42 | Train Loss: 0.0020044118815955394 | Validation Loss: 0.06454201979919318
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 43%|████▎     | 43/100 [3:12:20<4:25:33, 279.53s/it]

Epoch 43 | Train Loss: 0.0019378603458689112 | Validation Loss: 0.06608711244498795
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 44%|████▍     | 44/100 [3:16:54<4:19:27, 277.99s/it]

Epoch 44 | Train Loss: 0.0018162438704166562 | Validation Loss: 0.06748302114453722
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 45%|████▌     | 45/100 [3:21:28<4:13:40, 276.73s/it]

Epoch 45 | Train Loss: 0.001824759027254509 | Validation Loss: 0.06607662243372266
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 46%|████▌     | 46/100 [3:26:04<4:08:44, 276.37s/it]

Epoch 46 | Train Loss: 0.0018585893196157283 | Validation Loss: 0.06878139058801722
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 47%|████▋     | 47/100 [3:30:41<4:04:27, 276.74s/it]

Epoch 47 | Train Loss: 0.001814536408427456 | Validation Loss: 0.06818055252524767
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 48%|████▊     | 48/100 [3:35:21<4:00:42, 277.74s/it]

Epoch 48 | Train Loss: 0.0016787375177146815 | Validation Loss: 0.06490581589651868
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 49%|████▉     | 49/100 [3:40:02<3:56:52, 278.67s/it]

Epoch 49 | Train Loss: 0.001567385772068519 | Validation Loss: 0.0663019097707373
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 50%|█████     | 50/100 [3:44:37<3:51:21, 277.63s/it]

Epoch 50 | Train Loss: 0.001495040969247283 | Validation Loss: 0.06729608203502412
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 51%|█████     | 51/100 [3:49:26<3:49:21, 280.85s/it]

Epoch 51 | Train Loss: 0.00148778296053125 | Validation Loss: 0.06593306858013286
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 52%|█████▏    | 52/100 [3:54:05<3:44:18, 280.38s/it]

Epoch 52 | Train Loss: 0.0014447606804120976 | Validation Loss: 0.06765477387036414
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 53%|█████▎    | 53/100 [3:58:47<3:40:01, 280.89s/it]

Epoch 53 | Train Loss: 0.0015161958890934735 | Validation Loss: 0.06755695362238491
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 54%|█████▍    | 54/100 [4:03:27<3:35:02, 280.48s/it]

Epoch 54 | Train Loss: 0.0015867956208239775 | Validation Loss: 0.06608286079891185
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 55%|█████▌    | 55/100 [4:08:06<3:30:07, 280.17s/it]

Epoch 55 | Train Loss: 0.001525990008263357 | Validation Loss: 0.06868346358471095
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 56%|█████▌    | 56/100 [4:12:43<3:24:49, 279.31s/it]

Epoch 56 | Train Loss: 0.0014183901563228573 | Validation Loss: 0.06859928224869866
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 57%|█████▋    | 57/100 [4:17:20<3:19:29, 278.36s/it]

Epoch 57 | Train Loss: 0.0013116323609816997 | Validation Loss: 0.06857718925606063
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 58%|█████▊    | 58/100 [4:22:00<3:15:11, 278.85s/it]

Epoch 58 | Train Loss: 0.0012833971843671558 | Validation Loss: 0.06556503207204824
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 59%|█████▉    | 59/100 [4:26:40<3:10:52, 279.34s/it]

Epoch 59 | Train Loss: 0.0012237443739852299 | Validation Loss: 0.07202481866834011
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 60%|██████    | 60/100 [4:31:20<3:06:22, 279.57s/it]

Epoch 60 | Train Loss: 0.0012666471770595915 | Validation Loss: 0.06985571528685854
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 61%|██████    | 61/100 [4:36:01<3:02:00, 280.00s/it]

Epoch 61 | Train Loss: 0.0013087066329009961 | Validation Loss: 0.06816298811835178
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 62%|██████▏   | 62/100 [4:40:39<2:56:58, 279.44s/it]

Epoch 62 | Train Loss: 0.0012631373678838524 | Validation Loss: 0.06894621183659802
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 63%|██████▎   | 63/100 [4:46:38<3:07:01, 303.28s/it]

Epoch 63 | Train Loss: 0.0012223425776432526 | Validation Loss: 0.07105341175214407
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 64%|██████▍   | 64/100 [4:53:49<3:24:53, 341.49s/it]

Epoch 64 | Train Loss: 0.0012302796436415519 | Validation Loss: 0.07072502295387552
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 65%|██████▌   | 65/100 [5:01:16<3:37:46, 373.33s/it]

Epoch 65 | Train Loss: 0.0012164581069858589 | Validation Loss: 0.06839422053320611
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 66%|██████▌   | 66/100 [5:08:52<3:45:32, 398.01s/it]

Epoch 66 | Train Loss: 0.0012547166379388525 | Validation Loss: 0.06931078992784023
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 67%|██████▋   | 67/100 [5:20:32<4:28:39, 488.46s/it]

Epoch 67 | Train Loss: 0.0010938748719329144 | Validation Loss: 0.06951641269583017
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 68%|██████▊   | 68/100 [5:28:43<4:20:59, 489.37s/it]

Epoch 68 | Train Loss: 0.0010289351156744589 | Validation Loss: 0.07066011878641996
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 69%|██████▉   | 69/100 [5:37:31<4:18:53, 501.07s/it]

Epoch 69 | Train Loss: 0.0010225168175465014 | Validation Loss: 0.07002392006998366
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 70%|███████   | 70/100 [5:45:31<4:07:22, 494.74s/it]

Epoch 70 | Train Loss: 0.0010172858493913534 | Validation Loss: 0.06916271617754977
Batch: 1
Batch: 101
Batch: 201
Batch: 301


 71%|███████   | 71/100 [5:53:02<3:52:44, 481.54s/it]

Epoch 71 | Train Loss: 0.000997507691580621 | Validation Loss: 0.0707069143177347
Batch: 1


 71%|███████   | 71/100 [5:55:06<2:25:02, 300.09s/it]


KeyboardInterrupt: 

In [103]:
torch.save(model.state_dict(),"model_lr_e5.pth")

In [None]:
from sentence_transformers import SentenceTransformer, InputExample
from torch.utils.data import DataLoader

train_examples = []
for i,row in df_cosine_labels.iterrows():
    train_examples.append(InputExample(texts=[row["sentence1"], row["sentence2"]],label=row["label"]))
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)

val_examples = []
for i,row in validation_cosine_labels.iterrows():
    val_examples.append(InputExample(texts=[row["sentence1"], row["sentence2"]],label=row["label"]))


val_dataloader = DataLoader(val_examples, shuffle=True, batch_size=16)

In [110]:

from transformers import DistilBertTokenizer, DistilBertModel
from sentence_transformers import SentenceTransformer, models
from torch import nn
class ShroomformerV2(nn.Module):
    def __init__(self):
        super(ShroomformerV2, self).__init__()
        word_embedding_model = models.Transformer("bert-base-uncased", max_seq_length=256)
        pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
        downsampler_model = models.Dense(
            in_features=pooling_model.get_sentence_embedding_dimension(),
            out_features=256,
            activation_function=nn.Tanh())
        self.twin1 = SentenceTransformer(modules=[word_embedding_model, pooling_model, downsampler_model])
        self.twin2 = SentenceTransformer(modules=[word_embedding_model, pooling_model, downsampler_model])
        self.cos = nn.CosineSimilarity(dim=1, eps=1e-6)
    

    def forward(self, sentence1_encoded, sentence2_encoded):
        sentence1_embeddings =self.twin1()
        sentence2_embeddings = self.twin2(**sentence2_encoded)
        
        similarity_score = self.cos(sentence1_embeddings,sentence2_embeddings)
        return similarity_score
       



model = ShroomformerV2()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.00001)
num_epochs=1
best_val_loss = float('inf')
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Total parameters in the model:", total_params)
for epoch in tqdm(range(num_epochs)):
    running_loss = 0.0
    correct = 0
    total = 0
    model.train()
    for i, data in enumerate(train_dataloader, 0):
        if i%100==0:
            print(f"Batch: {i+1}")
        sentences,labels=data
   

     
        optimizer.zero_grad()
    #     labels = labels.float()
        outputs=model(sentence1,sentence2)
        print(outputs)
        break
    #     loss = criterion(outputs, labels)
    #     loss.backward()
    #     optimizer.step()
    #     running_loss += loss.item()

    # model.eval()
    # val_loss = 0.0
    # with torch.no_grad():
    #     for val_data in val_dataloader:
    #         val_sentences, val_labels = val_data
    #         val_sentence1, val_sentence2 = encode_pair(val_sentences)
    #         val_labels = val_labels.float()
    #         val_outputs = model(val_sentence1, val_sentence2)
    #         val_loss += criterion(val_outputs, val_labels).item()
    # val_loss /= len(val_dataloader)

    # print(f'Epoch {epoch+1} | Train Loss: {running_loss / len(train_dataloader)} | Validation Loss: {val_loss}')

    # if val_loss < best_val_loss:
    #     torch.save(model.state_dict(), "best_model_v2_lr1.pth")
    #     best_val_loss = val_loss


Total parameters in the model: 109679104


  0%|          | 0/1 [00:00<?, ?it/s]

Batch: 1
{'input_ids': tensor([[  101,  1037,  2450,  4472,  1037,  3538,  1997,  6240,  2007, 13724,
          1012,   102,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0],
        [  101, 12849,  3549,  2180,  1005,  1056,  3013,  3740,  6687,  9021,
          4804,   102,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0],
        [  101,  2048,  4584,  3322,  2056,  2055,  3998,  3667,  2018,  2042,
          4727,  2012,  6079,  5324,  1999,  2538,  2163,  1012,   102,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0],
        [  101,  1996,  2500,  2020,  2445,  4809,  1997,  1000,  2852,  1012,
         21087,  1005,  2047,  




TypeError: Sequential.forward() got an unexpected keyword argument 'input_ids'

In [112]:
from transformers import AutoTokenizer, AutoModel
import torch


# Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask


# Sentences we want sentence embeddings for
sentences = [
    "This framework generates embeddings for each input sentence",
    "Sentences are passed as a list of string.",
    "The quick brown fox jumps over the lazy dog.",
]


word_embedding_model = models.Transformer("bert-base-uncased", max_seq_length=256)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
downsampler_model = models.Dense(
            in_features=pooling_model.get_sentence_embedding_dimension(),
            out_features=256,
            activation_function=nn.Tanh())
twin1 = SentenceTransformer(modules=[word_embedding_model, pooling_model, downsampler_model])
# Load AutoModel from huggingface model repository
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

# Tokenize sentences
encoded_input = tokenizer(
    sentences, padding=True, truncation=True, max_length=128, return_tensors="pt"
)
print(encoded_input)
# Compute token embeddings
with torch.no_grad():
    model_output = twin1(**encoded_input)

# Perform pooling. In this case, mean pooling
sentence_embeddings = mean_pooling(model_output, encoded_input["attention_mask"])

{'input_ids': tensor([[  101,  2023,  7705, 19421,  7861,  8270,  4667,  2015,  2005,  2169,
          7953,  6251,   102],
        [  101, 11746,  2024,  2979,  2004,  1037,  2862,  1997,  5164,  1012,
           102,     0,     0],
        [  101,  1996,  4248,  2829,  4419, 14523,  2058,  1996, 13971,  3899,
          1012,   102,     0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]])}


TypeError: Sequential.forward() got an unexpected keyword argument 'input_ids'