In [35]:
import pandas as pd
import os
import numpy as np
import torch
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer

## Load data and forward pass

In [47]:
DATA_FOLDER = 'data'
MODEL_FOLDER = 'models'
path = 'ACHIEVEMENT.csv'
train_split = .8
df = pd.read_csv(os.path.join(os.getcwd(), DATA_FOLDER, path))
df_train = df.sample(frac=train_split, random_state=0)
df_test = df.drop(df_train.index).reset_index(drop=True)

In [37]:
VALUES = ['ACHIEVEMENT', 'BENEVOLENCE', 'CONFORMITY', 'HEDONISM', 'POWER', 'SECURITY', 'SELF-DIRECTION', 'STIMULATION', 'TRADITION', 'UNIVERSALISM']

class ValueDataset(Dataset):
    
    def __init__(self, tokenizer, df): 
        self.scenarios = df['scenario'].values.tolist()
        self.N = df.shape[0]

        inp = tokenizer(self.scenarios, return_tensors='pt', padding=True, truncation=True)
        self.input_ids = inp.get('input_ids')
        self.attention_mask = inp.get('attention_mask')
        self.token_type_ids = inp.get('token_type_ids')
        self.target = df['label'].values.tolist()
    
    def __getitem__(self, index):
        return self.input_ids[index], self.attention_mask[index], self.token_type_ids[index], self.target[index]

    def __len__(self):
        return self.N

In [86]:
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
model = AutoModelForSequenceClassification.from_pretrained('prajjwal1/bert-small', num_labels = 1)

train_dataset = ValueDataset(tokenizer, df_train)
test_dataset = ValueDataset(tokenizer, df_test)

train_dataloader = DataLoader(dataset=train_dataset, batch_size=10, shuffle=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=len(test_dataset))

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-small and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [87]:
(test_dataloader)

171

## Training loop

In [28]:
from torch.nn import MSELoss, Tanh

# Assuming model is already defined and loaded as in your code
loss_fn = MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
tanh = Tanh()

epochs = 200
epoch_loss = []

for e in tqdm(range(epochs)):
    model.train()  # Set the model to training mode
    batch_loss = []

    for input_ids, attention_mask, token_type_ids, targets in train_dataloader:
        # Forward pass
        optimizer.zero_grad()  # Clear gradients
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        logits = outputs.logits
        logits = tanh(logits).squeeze(-1)  # Apply tanh activation to bound outputs

        # Calculate loss
        loss = loss_fn(logits, targets.float())  # Make sure targets are floats
        batch_loss.append(loss.item())

        # Backward pass
        loss.backward()
        optimizer.step()

    # Calculate and print the average loss for the epoch
    avg_loss = np.mean(batch_loss)
    epoch_loss.append(avg_loss)
    tqdm.write(f'Epoch {e+1}/{epochs}, Loss: {avg_loss}')

# Print the loss after all epochs
print(epoch_loss)


  0%|          | 1/200 [00:16<54:20, 16.38s/it]

Epoch 1/200, Loss: 0.5997630135304686


  1%|          | 2/200 [00:33<56:14, 17.04s/it]

Epoch 2/200, Loss: 0.4668541535519172


  2%|▏         | 3/200 [00:48<52:28, 15.98s/it]

Epoch 3/200, Loss: 0.3946911189435185


  2%|▏         | 4/200 [01:03<50:10, 15.36s/it]

Epoch 4/200, Loss: 0.33546844815862353


  2%|▎         | 5/200 [01:17<48:48, 15.02s/it]

Epoch 5/200, Loss: 0.2841189795214197


  3%|▎         | 6/200 [01:32<48:54, 15.13s/it]

Epoch 6/200, Loss: 0.23198208374821622


  4%|▎         | 7/200 [01:47<48:31, 15.08s/it]

Epoch 7/200, Loss: 0.21820684101270593


  4%|▍         | 8/200 [02:02<47:49, 14.95s/it]

Epoch 8/200, Loss: 0.1681743427247241


  4%|▍         | 9/200 [02:16<47:00, 14.77s/it]

Epoch 9/200, Loss: 0.15145094719701918


  5%|▌         | 10/200 [02:31<46:25, 14.66s/it]

Epoch 10/200, Loss: 0.13406356083958046


  6%|▌         | 11/200 [02:45<46:00, 14.60s/it]

Epoch 11/200, Loss: 0.14118144591001497


  6%|▌         | 12/200 [03:00<46:02, 14.69s/it]

Epoch 12/200, Loss: 0.1150017509209937


  6%|▋         | 13/200 [03:14<45:30, 14.60s/it]

Epoch 13/200, Loss: 0.10545471747932227


  7%|▋         | 14/200 [03:29<45:02, 14.53s/it]

Epoch 14/200, Loss: 0.08475685637930165


  8%|▊         | 15/200 [03:43<44:45, 14.52s/it]

Epoch 15/200, Loss: 0.08123993673834248


  8%|▊         | 16/200 [03:58<44:27, 14.50s/it]

Epoch 16/200, Loss: 0.06970191827934721


  8%|▊         | 17/200 [04:12<44:14, 14.51s/it]

Epoch 17/200, Loss: 0.07586110834086286


  9%|▉         | 18/200 [04:28<44:56, 14.82s/it]

Epoch 18/200, Loss: 0.06474312705298264


 10%|▉         | 19/200 [04:43<44:34, 14.78s/it]

Epoch 19/200, Loss: 0.05759649798004092


 10%|█         | 20/200 [04:57<44:15, 14.75s/it]

Epoch 20/200, Loss: 0.06193342135436293


 10%|█         | 21/200 [05:12<43:41, 14.65s/it]

Epoch 21/200, Loss: 0.054157888219840286


 11%|█         | 22/200 [05:26<43:16, 14.59s/it]

Epoch 22/200, Loss: 0.0508661035961215


 12%|█▏        | 23/200 [05:41<42:55, 14.55s/it]

Epoch 23/200, Loss: 0.05101783485219315


 12%|█▏        | 24/200 [05:55<42:40, 14.55s/it]

Epoch 24/200, Loss: 0.044596257147149765


 12%|█▎        | 25/200 [06:09<42:17, 14.50s/it]

Epoch 25/200, Loss: 0.047763248986524086


 13%|█▎        | 26/200 [06:25<42:43, 14.73s/it]

Epoch 26/200, Loss: 0.04238380548422751


 14%|█▎        | 27/200 [06:39<42:22, 14.69s/it]

Epoch 27/200, Loss: 0.043172559120516846


 14%|█▍        | 28/200 [06:54<42:09, 14.71s/it]

Epoch 28/200, Loss: 0.04165002719863602


 14%|█▍        | 29/200 [07:08<41:36, 14.60s/it]

Epoch 29/200, Loss: 0.03969760822213214


 15%|█▌        | 30/200 [07:23<41:10, 14.53s/it]

Epoch 30/200, Loss: 0.035623029414294426


 16%|█▌        | 31/200 [07:37<40:47, 14.48s/it]

Epoch 31/200, Loss: 0.029680487375868404


 16%|█▌        | 32/200 [07:52<40:37, 14.51s/it]

Epoch 32/200, Loss: 0.03372379000960053


 16%|█▋        | 33/200 [08:06<40:16, 14.47s/it]

Epoch 33/200, Loss: 0.03454692482921309


 17%|█▋        | 34/200 [08:20<39:58, 14.45s/it]

Epoch 34/200, Loss: 0.03340187912428941


 18%|█▊        | 35/200 [08:35<39:44, 14.45s/it]

Epoch 35/200, Loss: 0.027913269772450778


 18%|█▊        | 36/200 [08:51<40:29, 14.82s/it]

Epoch 36/200, Loss: 0.029520464494176533


 18%|█▊        | 37/200 [09:05<40:12, 14.80s/it]

Epoch 37/200, Loss: 0.026323313855201654


 19%|█▉        | 38/200 [09:20<39:55, 14.79s/it]

Epoch 38/200, Loss: 0.028602539305237755


 20%|█▉        | 39/200 [09:35<39:21, 14.67s/it]

Epoch 39/200, Loss: 0.028244951739907265


 20%|██        | 40/200 [09:50<39:28, 14.80s/it]

Epoch 40/200, Loss: 0.028024526501911274


 20%|██        | 41/200 [10:05<39:58, 15.08s/it]

Epoch 41/200, Loss: 0.02603353997287543


 21%|██        | 42/200 [10:20<39:14, 14.90s/it]

Epoch 42/200, Loss: 0.026228947149913594


 22%|██▏       | 43/200 [10:34<38:40, 14.78s/it]

Epoch 43/200, Loss: 0.0277674147096611


 22%|██▏       | 44/200 [10:49<38:10, 14.68s/it]

Epoch 44/200, Loss: 0.026513087873657543


 22%|██▎       | 45/200 [11:03<37:42, 14.60s/it]

Epoch 45/200, Loss: 0.02600482932926304


 23%|██▎       | 46/200 [11:18<37:15, 14.52s/it]

Epoch 46/200, Loss: 0.022034280051045334


 24%|██▎       | 47/200 [11:32<37:01, 14.52s/it]

Epoch 47/200, Loss: 0.018559926030887425


 24%|██▍       | 48/200 [11:47<36:44, 14.50s/it]

Epoch 48/200, Loss: 0.021965512158889054


 24%|██▍       | 49/200 [12:01<36:24, 14.46s/it]

Epoch 49/200, Loss: 0.018766460059534595


 25%|██▌       | 50/200 [12:15<36:04, 14.43s/it]

Epoch 50/200, Loss: 0.01825523070073214


 26%|██▌       | 51/200 [12:30<35:46, 14.41s/it]

Epoch 51/200, Loss: 0.016812701773681285


 26%|██▌       | 52/200 [12:44<35:34, 14.42s/it]

Epoch 52/200, Loss: 0.018049068077453885


 26%|██▋       | 53/200 [12:59<35:21, 14.43s/it]

Epoch 53/200, Loss: 0.019993387277413538


 27%|██▋       | 54/200 [13:13<35:02, 14.40s/it]

Epoch 54/200, Loss: 0.018071123938737572


 28%|██▊       | 55/200 [13:27<34:54, 14.44s/it]

Epoch 55/200, Loss: 0.020118407749881346


 28%|██▊       | 56/200 [13:42<34:38, 14.43s/it]

Epoch 56/200, Loss: 0.01720493252767061


 28%|██▊       | 57/200 [13:56<34:34, 14.51s/it]

Epoch 57/200, Loss: 0.014939334020828423


 29%|██▉       | 58/200 [14:11<34:14, 14.47s/it]

Epoch 58/200, Loss: 0.0151109896445026


 30%|██▉       | 59/200 [14:25<33:54, 14.43s/it]

Epoch 59/200, Loss: 0.014132588884145345


 30%|███       | 60/200 [14:40<33:40, 14.43s/it]

Epoch 60/200, Loss: 0.01657594365162262


 30%|███       | 61/200 [14:54<33:30, 14.46s/it]

Epoch 61/200, Loss: 0.015368368172937113


 31%|███       | 62/200 [15:09<33:14, 14.45s/it]

Epoch 62/200, Loss: 0.013484698838378856


 32%|███▏      | 63/200 [15:23<33:05, 14.49s/it]

Epoch 63/200, Loss: 0.012948967627776057


 32%|███▏      | 64/200 [15:38<32:45, 14.45s/it]

Epoch 64/200, Loss: 0.012070991078634625


 32%|███▎      | 65/200 [15:52<32:28, 14.43s/it]

Epoch 65/200, Loss: 0.015402793712185128


 33%|███▎      | 66/200 [16:06<32:10, 14.41s/it]

Epoch 66/200, Loss: 0.013645333084412783


 34%|███▎      | 67/200 [16:21<31:53, 14.38s/it]

Epoch 67/200, Loss: 0.01738798795470401


 34%|███▍      | 68/200 [16:35<31:42, 14.41s/it]

Epoch 68/200, Loss: 0.014936539594192003


 34%|███▍      | 69/200 [16:49<31:22, 14.37s/it]

Epoch 69/200, Loss: 0.01478103036398365


 35%|███▌      | 70/200 [17:04<31:04, 14.34s/it]

Epoch 70/200, Loss: 0.010253183788322993


 36%|███▌      | 71/200 [17:18<30:54, 14.38s/it]

Epoch 71/200, Loss: 0.011381984760890296


 36%|███▌      | 72/200 [17:33<30:43, 14.40s/it]

Epoch 72/200, Loss: 0.010803611135866115


 36%|███▋      | 73/200 [17:47<30:29, 14.41s/it]

Epoch 73/200, Loss: 0.010112200443551916


 37%|███▋      | 74/200 [18:01<30:16, 14.41s/it]

Epoch 74/200, Loss: 0.009930856677744052


 38%|███▊      | 75/200 [18:17<31:04, 14.92s/it]

Epoch 75/200, Loss: 0.009880488293002481


 38%|███▊      | 76/200 [18:36<33:01, 15.98s/it]

Epoch 76/200, Loss: 0.011906213685532735


 38%|███▊      | 77/200 [18:52<32:40, 15.94s/it]

Epoch 77/200, Loss: 0.010225098214847832


 39%|███▉      | 78/200 [19:07<32:02, 15.75s/it]

Epoch 78/200, Loss: 0.010512857692842574


 40%|███▉      | 79/200 [19:22<31:12, 15.47s/it]

Epoch 79/200, Loss: 0.008816893202374164


 40%|████      | 80/200 [19:37<30:30, 15.26s/it]

Epoch 80/200, Loss: 0.008138110778510463


 40%|████      | 81/200 [19:51<29:55, 15.09s/it]

Epoch 81/200, Loss: 0.008431458743660292


 41%|████      | 82/200 [20:06<29:16, 14.88s/it]

Epoch 82/200, Loss: 0.010366187670497575


 42%|████▏     | 83/200 [20:20<28:43, 14.73s/it]

Epoch 83/200, Loss: 0.008871567013767966


 42%|████▏     | 84/200 [20:35<28:19, 14.65s/it]

Epoch 84/200, Loss: 0.00983579698578873


 42%|████▎     | 85/200 [20:49<28:03, 14.64s/it]

Epoch 85/200, Loss: 0.007518714585719441


 43%|████▎     | 86/200 [21:04<27:43, 14.59s/it]

Epoch 86/200, Loss: 0.006133479321800658


 44%|████▎     | 87/200 [21:19<27:51, 14.80s/it]

Epoch 87/200, Loss: 0.007557813499959699


 44%|████▍     | 88/200 [21:34<27:46, 14.88s/it]

Epoch 88/200, Loss: 0.007258889359840448


 44%|████▍     | 89/200 [21:49<27:27, 14.84s/it]

Epoch 89/200, Loss: 0.007235257375135046


 45%|████▌     | 90/200 [22:03<27:03, 14.76s/it]

Epoch 90/200, Loss: 0.009331946920913955


 46%|████▌     | 91/200 [22:18<26:37, 14.66s/it]

Epoch 91/200, Loss: 0.008147197061866198


 46%|████▌     | 92/200 [22:32<26:18, 14.62s/it]

Epoch 92/200, Loss: 0.0061934301626499055


 46%|████▋     | 93/200 [22:47<26:06, 14.64s/it]

Epoch 93/200, Loss: 0.007178361905971542


 47%|████▋     | 94/200 [23:01<25:45, 14.58s/it]

Epoch 94/200, Loss: 0.0071294596490443455


 48%|████▊     | 95/200 [23:16<25:26, 14.54s/it]

Epoch 95/200, Loss: 0.006835236311746775


 48%|████▊     | 96/200 [23:30<25:03, 14.45s/it]

Epoch 96/200, Loss: 0.007431231980276146


 48%|████▊     | 97/200 [23:44<24:40, 14.37s/it]

Epoch 97/200, Loss: 0.0065647223694384966


 49%|████▉     | 98/200 [23:59<24:19, 14.31s/it]

Epoch 98/200, Loss: 0.00673443406307395


 50%|████▉     | 99/200 [24:13<23:58, 14.24s/it]

Epoch 99/200, Loss: 0.007479673459032631


 50%|█████     | 100/200 [24:28<24:13, 14.53s/it]

Epoch 100/200, Loss: 0.006139694600858712


 50%|█████     | 101/200 [24:42<24:00, 14.55s/it]

Epoch 101/200, Loss: 0.006406188290397488


 51%|█████     | 102/200 [24:57<23:40, 14.50s/it]

Epoch 102/200, Loss: 0.0062829695989553265


 52%|█████▏    | 103/200 [25:11<23:20, 14.44s/it]

Epoch 103/200, Loss: 0.0061574689681182845


 52%|█████▏    | 104/200 [25:25<23:03, 14.41s/it]

Epoch 104/200, Loss: 0.006505410994261341


 52%|█████▎    | 105/200 [25:40<22:48, 14.41s/it]

Epoch 105/200, Loss: 0.0053214458198256225


 53%|█████▎    | 106/200 [25:54<22:38, 14.45s/it]

Epoch 106/200, Loss: 0.006871988227509934


 54%|█████▎    | 107/200 [26:09<22:24, 14.46s/it]

Epoch 107/200, Loss: 0.005545746251616789


 54%|█████▍    | 108/200 [26:23<22:11, 14.47s/it]

Epoch 108/200, Loss: 0.005924939888232536


 55%|█████▍    | 109/200 [26:38<22:07, 14.59s/it]

Epoch 109/200, Loss: 0.007314506768876606


 55%|█████▌    | 110/200 [26:53<21:51, 14.57s/it]

Epoch 110/200, Loss: 0.00568874223635235


 56%|█████▌    | 111/200 [27:07<21:32, 14.53s/it]

Epoch 111/200, Loss: 0.005957105801022355


 56%|█████▌    | 112/200 [27:23<21:51, 14.91s/it]

Epoch 112/200, Loss: 0.004286731732503065


 56%|█████▋    | 113/200 [27:38<21:48, 15.04s/it]

Epoch 113/200, Loss: 0.005230011385983612


 57%|█████▋    | 114/200 [27:53<21:31, 15.02s/it]

Epoch 114/200, Loss: 0.005671752923407821


 57%|█████▊    | 115/200 [28:08<21:04, 14.88s/it]

Epoch 115/200, Loss: 0.005263866153140755


 58%|█████▊    | 116/200 [28:23<20:48, 14.86s/it]

Epoch 116/200, Loss: 0.005148563119819037


 58%|█████▊    | 117/200 [28:37<20:26, 14.78s/it]

Epoch 117/200, Loss: 0.005374059868989972


 59%|█████▉    | 118/200 [28:52<20:03, 14.68s/it]

Epoch 118/200, Loss: 0.0072992570340579405


 60%|█████▉    | 119/200 [29:06<19:44, 14.63s/it]

Epoch 119/200, Loss: 0.006574532662879379


 60%|██████    | 120/200 [29:21<19:26, 14.58s/it]

Epoch 120/200, Loss: 0.005090595571411963


 60%|██████    | 121/200 [29:35<19:09, 14.55s/it]

Epoch 121/200, Loss: 0.0050309571790579785


 61%|██████    | 122/200 [29:50<19:00, 14.63s/it]

Epoch 122/200, Loss: 0.005830703108348762


 62%|██████▏   | 123/200 [30:04<18:43, 14.59s/it]

Epoch 123/200, Loss: 0.0069255427928958625


 62%|██████▏   | 124/200 [30:20<18:42, 14.77s/it]

Epoch 124/200, Loss: 0.004600719941874453


 62%|██████▎   | 125/200 [30:35<18:31, 14.82s/it]

Epoch 125/200, Loss: 0.006319335504077321


 63%|██████▎   | 126/200 [30:49<18:17, 14.83s/it]

Epoch 126/200, Loss: 0.0064829683780798395


 64%|██████▎   | 127/200 [31:04<17:55, 14.74s/it]

Epoch 127/200, Loss: 0.00583541669344624


 64%|██████▍   | 128/200 [31:19<17:38, 14.71s/it]

Epoch 128/200, Loss: 0.006730860186297028


 64%|██████▍   | 129/200 [31:33<17:19, 14.64s/it]

Epoch 129/200, Loss: 0.006879392543772195


 65%|██████▌   | 130/200 [31:48<17:06, 14.66s/it]

Epoch 130/200, Loss: 0.005956162858005289


 66%|██████▌   | 131/200 [32:02<16:48, 14.62s/it]

Epoch 131/200, Loss: 0.0043680839250252275


 66%|██████▌   | 132/200 [32:17<16:35, 14.64s/it]

Epoch 132/200, Loss: 0.004641107195377539


 66%|██████▋   | 133/200 [32:32<16:19, 14.62s/it]

Epoch 133/200, Loss: 0.004573526016336974


 67%|██████▋   | 134/200 [32:46<16:05, 14.62s/it]

Epoch 134/200, Loss: 0.004786580517142336


 68%|██████▊   | 135/200 [33:01<15:49, 14.61s/it]

Epoch 135/200, Loss: 0.0043108052731427515


 68%|██████▊   | 136/200 [33:15<15:35, 14.61s/it]

Epoch 136/200, Loss: 0.004137869705395449


 68%|██████▊   | 137/200 [33:31<15:39, 14.91s/it]

Epoch 137/200, Loss: 0.004417795384010853


 69%|██████▉   | 138/200 [33:46<15:22, 14.88s/it]

Epoch 138/200, Loss: 0.004370434861166227


 70%|██████▉   | 139/200 [34:01<15:05, 14.84s/it]

Epoch 139/200, Loss: 0.0035093051357788668


 70%|███████   | 140/200 [34:15<14:45, 14.76s/it]

Epoch 140/200, Loss: 0.004076552372746359


 70%|███████   | 141/200 [34:30<14:29, 14.74s/it]

Epoch 141/200, Loss: 0.003717032970440493


 71%|███████   | 142/200 [34:44<14:11, 14.69s/it]

Epoch 142/200, Loss: 0.003661657769597419


 72%|███████▏  | 143/200 [34:59<13:54, 14.65s/it]

Epoch 143/200, Loss: 0.004315061657942589


 72%|███████▏  | 144/200 [35:13<13:37, 14.59s/it]

Epoch 144/200, Loss: 0.007032933493947907


 72%|███████▎  | 145/200 [35:28<13:20, 14.56s/it]

Epoch 145/200, Loss: 0.00497712647231073


 73%|███████▎  | 146/200 [35:42<13:05, 14.55s/it]

Epoch 146/200, Loss: 0.005782174668190461


 74%|███████▎  | 147/200 [35:57<12:49, 14.52s/it]

Epoch 147/200, Loss: 0.004013818694173993


 74%|███████▍  | 148/200 [36:11<12:35, 14.52s/it]

Epoch 148/200, Loss: 0.00444119119345197


 74%|███████▍  | 149/200 [36:27<12:35, 14.82s/it]

Epoch 149/200, Loss: 0.003250602024954919


 75%|███████▌  | 150/200 [36:42<12:25, 14.90s/it]

Epoch 150/200, Loss: 0.003640107325855238


 76%|███████▌  | 151/200 [36:57<12:09, 14.90s/it]

Epoch 151/200, Loss: 0.004030598217007313


 76%|███████▌  | 152/200 [37:11<11:49, 14.78s/it]

Epoch 152/200, Loss: 0.00596858219469668


 76%|███████▋  | 153/200 [37:26<11:31, 14.70s/it]

Epoch 153/200, Loss: 0.004094030557072087


 77%|███████▋  | 154/200 [37:40<11:13, 14.64s/it]

Epoch 154/200, Loss: 0.004342160031386732


 78%|███████▊  | 155/200 [37:55<10:58, 14.64s/it]

Epoch 155/200, Loss: 0.0039893694420638694


 78%|███████▊  | 156/200 [38:10<10:42, 14.59s/it]

Epoch 156/200, Loss: 0.0036453122991825576


 78%|███████▊  | 157/200 [38:24<10:25, 14.54s/it]

Epoch 157/200, Loss: 0.003749298496151586


 79%|███████▉  | 158/200 [38:39<10:12, 14.58s/it]

Epoch 158/200, Loss: 0.003766392973678497


 80%|███████▉  | 159/200 [38:53<10:00, 14.64s/it]

Epoch 159/200, Loss: 0.003444701953254559


 80%|████████  | 160/200 [39:08<09:44, 14.61s/it]

Epoch 160/200, Loss: 0.003047954298064882


 80%|████████  | 161/200 [39:23<09:36, 14.79s/it]

Epoch 161/200, Loss: 0.0033085252961019


 81%|████████  | 162/200 [39:38<09:22, 14.80s/it]

Epoch 162/200, Loss: 0.0034479025272193594


 82%|████████▏ | 163/200 [39:53<09:09, 14.86s/it]

Epoch 163/200, Loss: 0.004039046637616053


 82%|████████▏ | 164/200 [40:08<08:53, 14.81s/it]

Epoch 164/200, Loss: 0.00401805393683363


 82%|████████▎ | 165/200 [40:22<08:34, 14.71s/it]

Epoch 165/200, Loss: 0.003798224619168795


 83%|████████▎ | 166/200 [40:37<08:18, 14.66s/it]

Epoch 166/200, Loss: 0.003876725637963028


 84%|████████▎ | 167/200 [40:51<08:02, 14.61s/it]

Epoch 167/200, Loss: 0.003216508911029213


 84%|████████▍ | 168/200 [41:06<07:47, 14.59s/it]

Epoch 168/200, Loss: 0.003026746414632172


 84%|████████▍ | 169/200 [41:20<07:30, 14.54s/it]

Epoch 169/200, Loss: 0.0035175121487607585


 85%|████████▌ | 170/200 [41:34<07:13, 14.46s/it]

Epoch 170/200, Loss: 0.0030600516346158406


 86%|████████▌ | 171/200 [41:49<07:01, 14.53s/it]

Epoch 171/200, Loss: 0.004191303084645778


 86%|████████▌ | 172/200 [42:03<06:45, 14.47s/it]

Epoch 172/200, Loss: 0.006832783739124695


 86%|████████▋ | 173/200 [42:18<06:30, 14.46s/it]

Epoch 173/200, Loss: 0.005225277516570793


 87%|████████▋ | 174/200 [42:36<06:44, 15.55s/it]

Epoch 174/200, Loss: 0.003146463855797344


 88%|████████▊ | 175/200 [42:51<06:26, 15.48s/it]

Epoch 175/200, Loss: 0.0034334479335128613


 88%|████████▊ | 176/200 [43:06<06:06, 15.26s/it]

Epoch 176/200, Loss: 0.0029545257243823826


 88%|████████▊ | 177/200 [43:21<05:49, 15.18s/it]

Epoch 177/200, Loss: 0.0036629493973821914


 89%|████████▉ | 178/200 [43:36<05:29, 14.99s/it]

Epoch 178/200, Loss: 0.00503599757148945


 90%|████████▉ | 179/200 [43:50<05:14, 14.95s/it]

Epoch 179/200, Loss: 0.004414567102566766


 90%|█████████ | 180/200 [44:05<04:57, 14.87s/it]

Epoch 180/200, Loss: 0.002991024138611755


 90%|█████████ | 181/200 [44:20<04:40, 14.76s/it]

Epoch 181/200, Loss: 0.0026658011842485303


 91%|█████████ | 182/200 [44:34<04:24, 14.70s/it]

Epoch 182/200, Loss: 0.0035253712874533526


 92%|█████████▏| 183/200 [44:49<04:09, 14.69s/it]

Epoch 183/200, Loss: 0.0031016324686841444


 92%|█████████▏| 184/200 [45:03<03:53, 14.62s/it]

Epoch 184/200, Loss: 0.003556521507351241


 92%|█████████▎| 185/200 [45:18<03:38, 14.59s/it]

Epoch 185/200, Loss: 0.003292371165235122


 93%|█████████▎| 186/200 [45:34<03:31, 15.10s/it]

Epoch 186/200, Loss: 0.0025834395948576444


 94%|█████████▎| 187/200 [45:49<03:15, 15.05s/it]

Epoch 187/200, Loss: 0.0028941707404674557


 94%|█████████▍| 188/200 [46:03<02:58, 14.86s/it]

Epoch 188/200, Loss: 0.002643730352629327


 94%|█████████▍| 189/200 [46:18<02:42, 14.78s/it]

Epoch 189/200, Loss: 0.003220028980412399


 95%|█████████▌| 190/200 [46:33<02:27, 14.78s/it]

Epoch 190/200, Loss: 0.0030554888645133033


 96%|█████████▌| 191/200 [46:47<02:12, 14.71s/it]

Epoch 191/200, Loss: 0.005632995813093789


 96%|█████████▌| 192/200 [47:02<01:57, 14.64s/it]

Epoch 192/200, Loss: 0.004536786459315065


 96%|█████████▋| 193/200 [47:17<01:42, 14.68s/it]

Epoch 193/200, Loss: 0.003153558441503581


 97%|█████████▋| 194/200 [47:32<01:28, 14.77s/it]

Epoch 194/200, Loss: 0.0024297483119037192


 98%|█████████▊| 195/200 [47:46<01:13, 14.79s/it]

Epoch 195/200, Loss: 0.0036615623826366864


 98%|█████████▊| 196/200 [48:01<00:58, 14.65s/it]

Epoch 196/200, Loss: 0.0030389611049980413


 98%|█████████▊| 197/200 [48:15<00:43, 14.64s/it]

Epoch 197/200, Loss: 0.0023273929759127805


 99%|█████████▉| 198/200 [48:31<00:30, 15.00s/it]

Epoch 198/200, Loss: 0.0026587110289677885


100%|█████████▉| 199/200 [48:46<00:15, 15.05s/it]

Epoch 199/200, Loss: 0.0028117373244047326


100%|██████████| 200/200 [49:01<00:00, 14.71s/it]

Epoch 200/200, Loss: 0.002899042176199989
[0.5997630135304686, 0.4668541535519172, 0.3946911189435185, 0.33546844815862353, 0.2841189795214197, 0.23198208374821622, 0.21820684101270593, 0.1681743427247241, 0.15145094719701918, 0.13406356083958046, 0.14118144591001497, 0.1150017509209937, 0.10545471747932227, 0.08475685637930165, 0.08123993673834248, 0.06970191827934721, 0.07586110834086286, 0.06474312705298264, 0.05759649798004092, 0.06193342135436293, 0.054157888219840286, 0.0508661035961215, 0.05101783485219315, 0.044596257147149765, 0.047763248986524086, 0.04238380548422751, 0.043172559120516846, 0.04165002719863602, 0.03969760822213214, 0.035623029414294426, 0.029680487375868404, 0.03372379000960053, 0.03454692482921309, 0.03340187912428941, 0.027913269772450778, 0.029520464494176533, 0.026323313855201654, 0.028602539305237755, 0.028244951739907265, 0.028024526501911274, 0.02603353997287543, 0.026228947149913594, 0.0277674147096611, 0.026513087873657543, 0.02600482932926304, 0.0220




## Testing

In [89]:
from sklearn.metrics import f1_score, recall_score, accuracy_score

def get_predictions():
    model.eval()
    with torch.no_grad():
        for data in tqdm(iter(test_dataloader)):
            input_ids, attention_mask, token_types_ids, targets = data
            inp = {'input_ids' : input_ids, 'attention_mask' : attention_mask, 'token_type_ids' : token_types_ids}
            output = model(**inp)
            logits = output.logits
            predictions = tanh(logits).squeeze(-1)

            return predictions, targets

In [95]:
model = torch.load(os.path.join(os.getcwd(), MODEL_FOLDER, 'bert_achievenemt_finetuned_1.pt'))

predictions, targets = get_predictions()

  0%|          | 0/1 [00:00<?, ?it/s]


In [100]:
predictions

tensor([ 0.7111,  0.3014, -0.1772, -0.9845,  0.3819, -0.7010,  0.9651,  0.8240,
        -0.9924, -0.0203, -0.5209, -0.9924,  0.5603, -0.0478, -0.9897, -0.9885,
        -0.0386, -0.9899, -0.9684, -0.9909,  0.0157, -0.0327, -0.2532, -0.9894,
        -0.9805,  0.0594, -0.5274, -0.0174, -0.9465, -0.3620, -0.1607, -0.0396,
        -0.1207, -0.0403, -0.9914,  0.0590, -0.0224, -0.4950, -0.0397,  0.9047,
        -0.0497,  0.4768, -0.8985, -0.0495, -0.9853, -0.9833, -0.7307,  0.1535,
        -0.9081, -0.9512, -0.9905, -0.0721, -0.9934, -0.9882, -0.9269,  0.9787,
        -0.9856, -0.8335, -0.2841, -0.9167, -0.0183,  0.1737, -0.9943,  0.4433,
        -0.9860, -0.8799, -0.9916, -0.9899, -0.9869, -0.9776, -0.9925, -0.5803,
        -0.0145,  0.0096, -0.9911, -0.6890, -0.9203, -0.9878, -0.5461, -0.9840,
         0.0052, -0.9824, -0.2332, -0.0239, -0.9939, -0.8248,  0.3475, -0.1884,
        -0.9440, -0.9849, -0.0393, -0.1398, -0.0495, -0.0300, -0.5077,  0.0191,
         0.3301,  0.1692,  0.9925, -0.55

In [107]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

def regression_to_classification_metrics(predictions, targets):
    # Define thresholds for class assignment based on regression predictions
    thresholds = [-0.5, 0.5]

    # Convert regression predictions to discrete labels [-1, 0, 1]
    labels = np.digitize(predictions, bins=thresholds)

    # Calculate precision, recall, and F1 scores
    accuracy = accuracy_score(targets, labels)
    precision = precision_score(targets, labels, average='weighted')
    recall = recall_score(targets, labels, average='weighted')
    f1 = f1_score(targets, labels, average='weighted')

    return accuracy, f1, recall, precision

# Example usage:
# predictions, targets = get_predictions()
accuracy, f1, recall, precision = regression_to_classification_metrics(predictions = predictions, targets = targets)
print("F1 Score:", f1)
print("Recall Score:", recall)
print("Precision Score:", precision)

F1 Score: 0.22234156820622983
Recall Score: 0.2573099415204678
Precision Score: 0.20230223343398437


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


## OLD

In [None]:
epochs = 30
optimizer = torch.optim.Adam(model.parameters())
loss_fn = torch.nn.functional.mse_loss
batch_loss = []
epoch_loss = []

for e in tqdm(range(epochs)):
    for data in train_dataloader:
        # Unpack the data from the dataloader
        input_ids, attention_mask, token_types_ids, targets = data
        
        # Clear previous gradients
        optimizer.zero_grad()

        # Prepare inputs for the model
        inp = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'token_type_ids': token_types_ids
        }

        # Forward pass
        output = model(**inp)
        loss = loss_fn(output.logits, targets.float().unsqueeze(1))

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Store loss
        batch_loss.append(loss.item())

    # Calculate mean loss for the epoch
    epoch_loss.append(np.mean(batch_loss))
    batch_loss = []

    # Print the mean loss for the epoch
    tqdm.write(f'Epoch {e+1}/{epochs}, Loss: {epoch_loss[-1]}')

# Print the loss after all epochs
print(epoch_loss)
