In [1]:
from transformers import (
    BertModel, 
    AutoConfig, 
    AutoTokenizer, 
    Trainer,
    TrainingArguments
    )
import torch.nn as nn
import datasets
import csv


In [2]:
import torch
import numpy as np

torch.cuda.is_available()

True

In [3]:
import csv

In [4]:
## Load dataset
feats_fp = open("BERT_X.csv", "r")
labels_fp = open("BERT_y.csv", "r")
feats = csv.reader(feats_fp)
labels = csv.reader(labels_fp)

# skip header
next(feats)
next(labels)

unscaled_data = {'text': [], 'label': []}
nn_data = []
for row in feats:
    unscaled_data['text'].append(row[0].strip().replace("\n", " "))
    nn_feats = [float(col.strip().replace("\n", "")) for col in row[1:]]
    
    
    nn_data.append(nn_feats[-2]) # company employee count only
for i,row in enumerate(labels):
    unscaled_data['label'].append(float(row[0].strip().replace("\n", "")))
nn_data = np.array(nn_data).reshape(-1,1).astype(np.float32)
nn_data[:5]

array([[1171.],
       [  36.],
       [1227.],
       [1260.],
       [  36.]], dtype=float32)

In [5]:
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
scaled_labels = scaler.fit_transform(np.array(unscaled_data['label']).reshape(-1,1)).flatten()
scaled_data = unscaled_data
scaled_data['label'] = scaled_labels

employee_count_scaler = StandardScaler()
scaled_employee_count =employee_count_scaler.fit_transform(nn_data[:,0].reshape(-1,1)).flatten()
nn_data[:,0] = scaled_employee_count
scaled_labels, nn_data[:,0]

(array([-0.58786439, -0.94649659,  2.86819603, ..., -1.04441134,
        -0.86218111, -0.69471581]),
 array([-0.30006495, -0.31242564, -0.2994551 , ..., -0.31238207,
        -0.3128286 , -0.31278503], dtype=float32))

In [6]:
assert len(scaled_data['text']) == len(scaled_data['label']) == nn_data.shape[0]
dataset = datasets.Dataset.from_dict(scaled_data)

In [7]:
print(dataset[0]['text'])
scaler.inverse_transform(np.array([dataset[0]['label']]).reshape(-1,1))

Overview  HearingLife is a national hearing care company and part of the Demant Group, a global leader in hearing healthcare built on a heritage of care, health, and innovation since 1904. HearingLife operates more than 600 hearing care centers across 42 states. We follow a scientific, results-oriented approach to hearing healthcare that is provided by highly skilled and caring professionals. Our vision is to help more people hear better through life-changing hearing health delivered by the best personalized care. This Team Member must uphold the HearingLife Core Values:   We create trust  We are team players  We apply a can-do attitude  We create innovative solutions   Responsibilities  You will help more people hear better by providing clinical expertise to diagnose and treat hearing loss while ensuring a positive patient experience. The Hearing Care Provider acts in accordance with required industry and state professional licensing standards and local practice scope and is responsib

array([[63000.]])

In [8]:
# Tokenize the dataset
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
def preprocess(examples):
    return tokenizer(examples['text'], truncation=True, padding=True)

tokenized_dataset = dataset.map(preprocess, batched=True)

Downloading:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/455k [00:00<?, ?B/s]

  0%|          | 0/27 [00:00<?, ?ba/s]

In [9]:
# training_args = TrainingArguments(
#     output_dir='./results',          
#     num_train_epochs=3,              
#     per_device_train_batch_size=16,  
#     learning_rate=5e-5,               
#     warmup_steps=500,                
#     weight_decay=0.01,              
#     logging_dir='./logs',
# )


In [10]:
tokenized_dataset

Dataset({
    features: ['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 26990
})

In [11]:
from torch.utils.data import Dataset

class RegressionDataset(Dataset):
    def __init__(self, input_ids, attention_mask, labels, nn_data):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.labels = labels
        self.nn_data = nn_data

    def __getitem__(self, idx):
        item = {
            'input_ids': torch.tensor(self.input_ids[idx]),
            'attention_mask': torch.tensor(self.attention_mask[idx]),
            'labels': torch.tensor(self.labels[idx]),
            'nn_data': torch.tensor(self.nn_data[idx])
        }
        return item

    def __len__(self):
        return len(self.labels)  # Assuming all data entries have labels

    
    

In [12]:
input_ids = tokenized_dataset['input_ids']
attention_mask = tokenized_dataset['attention_mask']
labels = dataset['label']  # Assuming your labels are in the original dataset
nn_data = nn_data

reg_dataset = RegressionDataset(input_ids, attention_mask, labels, nn_data)
dataloader = torch.utils.data.DataLoader(reg_dataset, batch_size=64, shuffle=True)



In [13]:
reg_dataset[0]['nn_data']

tensor([-0.3001])

In [14]:

model_name = "bert-base-uncased"
config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
bert_model = BertModel.from_pretrained(model_name)

config.problem_type = 'regression'

Downloading:   0%|          | 0.00/420M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [17]:

# Add linear layer
output_size = 1  


# Combine BERT and the linear layer
class BertWithLinear(nn.Module):
    def __init__(self):
        super(BertWithLinear, self).__init__()
        self.bert = bert_model.to('cuda')
    
        self.nn = nn.Sequential(
            nn.Linear(config.hidden_size + len(nn_data[0]) , 128),
            nn.ReLU(),
            nn.Linear(128, 16),
            nn.ReLU(),
            nn.Linear(16, output_size)
        ).to('cuda')
        

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, labels=None, nn_data = None):
        output = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds)
        # Use pooled output for classification/regression
        pooled_output = output.pooler_output
        nn_data = torch.cat((pooled_output, nn_data), 1)
        nn_output = self.nn(nn_data)
        return nn_output

    

model = BertWithLinear().to('cuda')

In [18]:
# freeze/unfreeze BERT pretrained weights
for param in model.bert.embeddings.parameters():
    param.requires_grad = True
model = nn.DataParallel(model)

In [19]:
model.parameters()

<generator object Module.parameters at 0x2baa3d731308>

In [20]:
loss_fn = nn.MSELoss()  # Mean Squared Error is common for regression
optimizer = torch.optim.Adam([
    {'params': model.module.bert.parameters(), 'lr': 1e-5}, 
    {'params': model.module.nn.parameters(), 'lr': 2e-5} # our neural net
])
num_epochs = 10
    
    
from torch.utils.data import random_split 

train_size = int(0.8 * len(reg_dataset))  # 80% of the dataset for training
val_size = len(reg_dataset) - train_size 

# Create the train and validation datasets
train_dataset, val_dataset = random_split(reg_dataset, [train_size, val_size])

# Create DataLoaders
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False)  # No need to shuffle validation
train_losses = []
val_losses = []
# Modify your training loop (add validation)
for epoch in range(num_epochs):
    print(epoch,"/",num_epochs)
    b=0
    ## Training Phase
    model.train()  # Set model to training mode
    current_epoch_train_losses = []
    for batch in train_dataloader:
        input_ids = batch['input_ids'].to('cuda') 
        attention_mask = batch['attention_mask'].to('cuda')
        labels = batch['labels'].to('cuda')
        data = batch['nn_data'].to('cuda')

        optimizer.zero_grad()
        
        outputs = model(input_ids, attention_mask=attention_mask, nn_data=data)
        
        loss = loss_fn(outputs.squeeze(1), labels) # Ensure outputs are single-dimensional
        current_epoch_train_losses.append(loss.item())
        loss.backward()
        optimizer.step() 
        print(f"batch {b} complete. Loss: ",loss.item())
        b+=1
    train_losses.append(sum(current_epoch_train_losses)/len(current_epoch_train_losses))
    ## Validation Phase
    model.eval()   # Set model to evaluation mode
    val_loss = 0
    with torch.no_grad():  # Disable gradient calculation for validation
        for batch in val_dataloader:
            input_ids = batch['input_ids'].to('cuda') 
            attention_mask = batch['attention_mask'].to('cuda')
            labels = batch['labels'].to('cuda')
            data = batch['nn_data'].to('cuda')
            outputs = model(input_ids, attention_mask=attention_mask,nn_data=data)
            val_loss += loss_fn(outputs.squeeze(1), labels).item()

    val_loss /= len(val_dataloader)
    val_losses.append(val_loss)
    print(f"Epoch {epoch} Validation Loss: {val_loss}")

    # checkpoint model
    if epoch % 5 == 0:
        torch.save(model.state_dict(), f"model_checkpoint_{epoch}.pth")


0 / 10
batch 0 complete. Loss:  0.9602559208869934
batch 1 complete. Loss:  1.0854591131210327
batch 2 complete. Loss:  0.9973961114883423
batch 3 complete. Loss:  1.4335261583328247
batch 4 complete. Loss:  1.3671677112579346
batch 5 complete. Loss:  1.295344352722168
batch 6 complete. Loss:  0.7374680638313293
batch 7 complete. Loss:  1.401750087738037
batch 8 complete. Loss:  0.7906998991966248
batch 9 complete. Loss:  0.8172276020050049
batch 10 complete. Loss:  1.7089563608169556
batch 11 complete. Loss:  0.8064703941345215
batch 12 complete. Loss:  0.7174537181854248
batch 13 complete. Loss:  0.865260124206543
batch 14 complete. Loss:  0.5848900079727173
batch 15 complete. Loss:  1.352117657661438
batch 16 complete. Loss:  0.954603374004364
batch 17 complete. Loss:  0.8742499351501465
batch 18 complete. Loss:  0.6452634334564209
batch 19 complete. Loss:  0.82177734375
batch 20 complete. Loss:  0.9450459480285645
batch 21 complete. Loss:  0.6680349111557007
batch 22 complete. Loss

batch 181 complete. Loss:  0.39483755826950073
batch 182 complete. Loss:  0.5694147348403931
batch 183 complete. Loss:  0.4627424478530884
batch 184 complete. Loss:  0.2930580973625183
batch 185 complete. Loss:  0.1919148564338684
batch 186 complete. Loss:  0.3649362325668335
batch 187 complete. Loss:  0.3447125256061554
batch 188 complete. Loss:  0.47085410356521606
batch 189 complete. Loss:  0.5182582139968872
batch 190 complete. Loss:  0.43994611501693726
batch 191 complete. Loss:  0.23894815146923065
batch 192 complete. Loss:  0.6318080425262451
batch 193 complete. Loss:  0.39887145161628723
batch 194 complete. Loss:  0.3728852868080139
batch 195 complete. Loss:  0.25238123536109924
batch 196 complete. Loss:  0.2812676727771759
batch 197 complete. Loss:  0.34678804874420166
batch 198 complete. Loss:  0.5716134309768677
batch 199 complete. Loss:  0.23236222565174103
batch 200 complete. Loss:  0.6836685538291931
batch 201 complete. Loss:  0.2768794298171997
batch 202 complete. Loss: 

batch 358 complete. Loss:  0.35365960001945496
batch 359 complete. Loss:  0.46692293882369995
batch 360 complete. Loss:  0.2560327649116516
batch 361 complete. Loss:  0.24280674755573273
batch 362 complete. Loss:  0.24585723876953125
batch 363 complete. Loss:  0.2919696867465973
batch 364 complete. Loss:  0.19065041840076447
batch 365 complete. Loss:  0.5062116384506226
batch 366 complete. Loss:  0.20257794857025146
batch 367 complete. Loss:  0.2734726369380951
batch 368 complete. Loss:  0.16914090514183044
batch 369 complete. Loss:  0.25578540563583374
batch 370 complete. Loss:  0.18478341400623322
batch 371 complete. Loss:  0.2783803343772888
batch 372 complete. Loss:  0.18059737980365753
batch 373 complete. Loss:  0.20651701092720032
batch 374 complete. Loss:  0.2137545496225357
batch 375 complete. Loss:  0.5165082812309265
batch 376 complete. Loss:  0.3242429494857788
batch 377 complete. Loss:  0.24461817741394043
batch 378 complete. Loss:  0.19612160325050354
batch 379 complete. L

batch 535 complete. Loss:  0.18408045172691345
batch 536 complete. Loss:  0.2844407558441162
batch 537 complete. Loss:  0.35551685094833374
batch 538 complete. Loss:  0.18453240394592285
batch 539 complete. Loss:  0.39302054047584534
batch 540 complete. Loss:  0.5946670174598694
batch 541 complete. Loss:  0.256773442029953
batch 542 complete. Loss:  0.17932510375976562
batch 543 complete. Loss:  0.2490367293357849
batch 544 complete. Loss:  0.6271642446517944
batch 545 complete. Loss:  0.3792073726654053
batch 546 complete. Loss:  0.3869614601135254
batch 547 complete. Loss:  0.3545866012573242
batch 548 complete. Loss:  0.3681737184524536
batch 549 complete. Loss:  0.7042070627212524
batch 550 complete. Loss:  0.23133337497711182
batch 551 complete. Loss:  0.29018229246139526
batch 552 complete. Loss:  0.363040566444397
batch 553 complete. Loss:  0.2969742715358734
batch 554 complete. Loss:  0.3672666847705841
batch 555 complete. Loss:  0.16351954638957977
batch 556 complete. Loss:  0

batch 36 complete. Loss:  0.2930046319961548
batch 37 complete. Loss:  0.23612724244594574
batch 38 complete. Loss:  0.1264781504869461
batch 39 complete. Loss:  0.09751620888710022
batch 40 complete. Loss:  0.2730185389518738
batch 41 complete. Loss:  0.17502833902835846
batch 42 complete. Loss:  0.32565608620643616
batch 43 complete. Loss:  0.22074949741363525
batch 44 complete. Loss:  0.2422146499156952
batch 45 complete. Loss:  0.14602452516555786
batch 46 complete. Loss:  0.17501334846019745
batch 47 complete. Loss:  0.17765726149082184
batch 48 complete. Loss:  0.10285371541976929
batch 49 complete. Loss:  0.16085514426231384
batch 50 complete. Loss:  0.12417875230312347
batch 51 complete. Loss:  0.3295987844467163
batch 52 complete. Loss:  0.14273270964622498
batch 53 complete. Loss:  0.1438235342502594
batch 54 complete. Loss:  0.21680273115634918
batch 55 complete. Loss:  0.2288345843553543
batch 56 complete. Loss:  0.19796226918697357
batch 57 complete. Loss:  0.1463388055562

batch 213 complete. Loss:  0.1279658079147339
batch 214 complete. Loss:  0.17669788002967834
batch 215 complete. Loss:  0.09679155051708221
batch 216 complete. Loss:  0.24854260683059692
batch 217 complete. Loss:  0.23509883880615234
batch 218 complete. Loss:  0.16738805174827576
batch 219 complete. Loss:  0.19928371906280518
batch 220 complete. Loss:  0.1626099944114685
batch 221 complete. Loss:  0.15612855553627014
batch 222 complete. Loss:  0.15812820196151733
batch 223 complete. Loss:  0.16299021244049072
batch 224 complete. Loss:  0.22522187232971191
batch 225 complete. Loss:  0.1851237416267395
batch 226 complete. Loss:  0.12575997412204742
batch 227 complete. Loss:  0.305478036403656
batch 228 complete. Loss:  0.1287393867969513
batch 229 complete. Loss:  0.15206243097782135
batch 230 complete. Loss:  0.2311449646949768
batch 231 complete. Loss:  0.21029558777809143
batch 232 complete. Loss:  0.1321716606616974
batch 233 complete. Loss:  0.13968652486801147
batch 234 complete. L

batch 389 complete. Loss:  0.152694970369339
batch 390 complete. Loss:  0.10878217220306396
batch 391 complete. Loss:  0.2563960552215576
batch 392 complete. Loss:  0.1635170876979828
batch 393 complete. Loss:  0.09470626711845398
batch 394 complete. Loss:  0.1388944536447525
batch 395 complete. Loss:  0.1838860809803009
batch 396 complete. Loss:  0.19242513179779053
batch 397 complete. Loss:  0.18453791737556458
batch 398 complete. Loss:  0.0837194174528122
batch 399 complete. Loss:  0.14358624815940857
batch 400 complete. Loss:  0.11853095144033432
batch 401 complete. Loss:  0.1163649931550026
batch 402 complete. Loss:  0.20838767290115356
batch 403 complete. Loss:  0.13274551928043365
batch 404 complete. Loss:  0.10310275107622147
batch 405 complete. Loss:  0.1175161749124527
batch 406 complete. Loss:  0.12079866230487823
batch 407 complete. Loss:  0.19651912152767181
batch 408 complete. Loss:  0.07774337381124496
batch 409 complete. Loss:  0.1804559975862503
batch 410 complete. Los

batch 565 complete. Loss:  0.14047521352767944
batch 566 complete. Loss:  0.12054614722728729
batch 567 complete. Loss:  0.22010000050067902
batch 568 complete. Loss:  0.13878226280212402
batch 569 complete. Loss:  0.11630702018737793
batch 570 complete. Loss:  0.20643389225006104
batch 571 complete. Loss:  0.19223234057426453
batch 572 complete. Loss:  0.18313387036323547
batch 573 complete. Loss:  0.07815878093242645
batch 574 complete. Loss:  0.08374965935945511
batch 575 complete. Loss:  0.1472361981868744
batch 576 complete. Loss:  0.12632523477077484
batch 577 complete. Loss:  0.13019660115242004
batch 578 complete. Loss:  0.12800468504428864
batch 579 complete. Loss:  0.08949653804302216
batch 580 complete. Loss:  0.14760613441467285
batch 581 complete. Loss:  0.08345597982406616
batch 582 complete. Loss:  0.14344000816345215
batch 583 complete. Loss:  0.10627393424510956
batch 584 complete. Loss:  0.08466748893260956
batch 585 complete. Loss:  0.2168249785900116
batch 586 compl

batch 66 complete. Loss:  0.09765547513961792
batch 67 complete. Loss:  0.15320508182048798
batch 68 complete. Loss:  0.1316205859184265
batch 69 complete. Loss:  0.07785971462726593
batch 70 complete. Loss:  0.18158361315727234
batch 71 complete. Loss:  0.11465819180011749
batch 72 complete. Loss:  0.14952103793621063
batch 73 complete. Loss:  0.09216569364070892
batch 74 complete. Loss:  0.11903417110443115
batch 75 complete. Loss:  0.10805921256542206
batch 76 complete. Loss:  0.08020072430372238
batch 77 complete. Loss:  0.05360613018274307
batch 78 complete. Loss:  0.06786008179187775
batch 79 complete. Loss:  0.05286245048046112
batch 80 complete. Loss:  0.14384183287620544
batch 81 complete. Loss:  0.12761864066123962
batch 82 complete. Loss:  0.16745565831661224
batch 83 complete. Loss:  0.05239507183432579
batch 84 complete. Loss:  0.14531303942203522
batch 85 complete. Loss:  0.0783238410949707
batch 86 complete. Loss:  0.10473880171775818
batch 87 complete. Loss:  0.03578078

batch 242 complete. Loss:  0.08807379007339478
batch 243 complete. Loss:  0.09311623871326447
batch 244 complete. Loss:  0.08161437511444092
batch 245 complete. Loss:  0.14735190570354462
batch 246 complete. Loss:  0.0707920640707016
batch 247 complete. Loss:  0.06722882390022278
batch 248 complete. Loss:  0.06686219573020935
batch 249 complete. Loss:  0.0944141075015068
batch 250 complete. Loss:  0.0755278468132019
batch 251 complete. Loss:  0.14918358623981476
batch 252 complete. Loss:  0.13118475675582886
batch 253 complete. Loss:  0.07471812516450882
batch 254 complete. Loss:  0.06456221640110016
batch 255 complete. Loss:  0.08308537304401398
batch 256 complete. Loss:  0.05365252494812012
batch 257 complete. Loss:  0.23417840898036957
batch 258 complete. Loss:  0.06600882112979889
batch 259 complete. Loss:  0.09317721426486969
batch 260 complete. Loss:  0.06825602054595947
batch 261 complete. Loss:  0.09892091155052185
batch 262 complete. Loss:  0.1052328497171402
batch 263 complet

batch 417 complete. Loss:  0.2684294581413269
batch 418 complete. Loss:  0.13419395685195923
batch 419 complete. Loss:  0.11566781997680664
batch 420 complete. Loss:  0.12602835893630981
batch 421 complete. Loss:  0.05726883187890053
batch 422 complete. Loss:  0.09540236741304398
batch 423 complete. Loss:  0.08717142790555954
batch 424 complete. Loss:  0.12757322192192078
batch 425 complete. Loss:  0.09940940141677856
batch 426 complete. Loss:  0.0733843594789505
batch 427 complete. Loss:  0.1460476517677307
batch 428 complete. Loss:  0.10810013115406036
batch 429 complete. Loss:  0.07182356715202332
batch 430 complete. Loss:  0.11468101292848587
batch 431 complete. Loss:  0.09786740690469742
batch 432 complete. Loss:  0.11425159871578217
batch 433 complete. Loss:  0.20156928896903992
batch 434 complete. Loss:  0.19018056988716125
batch 435 complete. Loss:  0.05454624071717262
batch 436 complete. Loss:  0.08136945962905884
batch 437 complete. Loss:  0.07985341548919678
batch 438 comple

batch 592 complete. Loss:  0.06522567570209503
batch 593 complete. Loss:  0.12383551150560379
batch 594 complete. Loss:  0.09156092256307602
batch 595 complete. Loss:  0.10231558978557587
batch 596 complete. Loss:  0.07852701097726822
batch 597 complete. Loss:  0.07018834352493286
batch 598 complete. Loss:  0.08010721206665039
batch 599 complete. Loss:  0.057731565088033676
batch 600 complete. Loss:  0.07779738306999207
batch 601 complete. Loss:  0.11136747151613235
batch 602 complete. Loss:  0.14442335069179535
batch 603 complete. Loss:  0.1619734913110733
batch 604 complete. Loss:  0.0806354284286499
batch 605 complete. Loss:  0.12109079957008362
batch 606 complete. Loss:  0.11701767891645432
batch 607 complete. Loss:  0.128215491771698
batch 608 complete. Loss:  0.11157921701669693
batch 609 complete. Loss:  0.06555428355932236
batch 610 complete. Loss:  0.11504891514778137
batch 611 complete. Loss:  0.12202811986207962
batch 612 complete. Loss:  0.1053919568657875
batch 613 complet

batch 93 complete. Loss:  0.037607915699481964
batch 94 complete. Loss:  0.035271160304546356
batch 95 complete. Loss:  0.06440618634223938
batch 96 complete. Loss:  0.07029242068529129
batch 97 complete. Loss:  0.07156956195831299
batch 98 complete. Loss:  0.04751601815223694
batch 99 complete. Loss:  0.07131589949131012
batch 100 complete. Loss:  0.030838072299957275
batch 101 complete. Loss:  0.07926388084888458
batch 102 complete. Loss:  0.08423209190368652
batch 103 complete. Loss:  0.05804945528507233
batch 104 complete. Loss:  0.07958559691905975
batch 105 complete. Loss:  0.09968545287847519
batch 106 complete. Loss:  0.1181524246931076
batch 107 complete. Loss:  0.044883258640766144
batch 108 complete. Loss:  0.04068541154265404
batch 109 complete. Loss:  0.07790213078260422
batch 110 complete. Loss:  0.09441977739334106
batch 111 complete. Loss:  0.07775460928678513
batch 112 complete. Loss:  0.028050709515810013
batch 113 complete. Loss:  0.07723948359489441
batch 114 comple

batch 268 complete. Loss:  0.09066753089427948
batch 269 complete. Loss:  0.07829020172357559
batch 270 complete. Loss:  0.049022745341062546
batch 271 complete. Loss:  0.056061021983623505
batch 272 complete. Loss:  0.059476546943187714
batch 273 complete. Loss:  0.08719806373119354
batch 274 complete. Loss:  0.05825509876012802
batch 275 complete. Loss:  0.05074404180049896
batch 276 complete. Loss:  0.033602274954319
batch 277 complete. Loss:  0.06399291753768921
batch 278 complete. Loss:  0.05387556180357933
batch 279 complete. Loss:  0.12409976124763489
batch 280 complete. Loss:  0.13029617071151733
batch 281 complete. Loss:  0.07108783721923828
batch 282 complete. Loss:  0.11023349314928055
batch 283 complete. Loss:  0.057840269058942795
batch 284 complete. Loss:  0.033764105290174484
batch 285 complete. Loss:  0.03419112041592598
batch 286 complete. Loss:  0.07504445314407349
batch 287 complete. Loss:  0.09615643322467804
batch 288 complete. Loss:  0.07154669612646103
batch 289 

batch 443 complete. Loss:  0.04351583495736122
batch 444 complete. Loss:  0.04432535171508789
batch 445 complete. Loss:  0.0646856278181076
batch 446 complete. Loss:  0.07040262222290039
batch 447 complete. Loss:  0.04896827042102814
batch 448 complete. Loss:  0.055887408554553986
batch 449 complete. Loss:  0.05233670026063919
batch 450 complete. Loss:  0.14926180243492126
batch 451 complete. Loss:  0.06601021438837051
batch 452 complete. Loss:  0.0727389007806778
batch 453 complete. Loss:  0.04064381495118141
batch 454 complete. Loss:  0.05379541963338852
batch 455 complete. Loss:  0.06768357753753662
batch 456 complete. Loss:  0.0698133185505867
batch 457 complete. Loss:  0.1176946759223938
batch 458 complete. Loss:  0.03332562744617462
batch 459 complete. Loss:  0.11385688185691833
batch 460 complete. Loss:  0.07295764982700348
batch 461 complete. Loss:  0.06932429224252701
batch 462 complete. Loss:  0.03507234901189804
batch 463 complete. Loss:  0.14320406317710876
batch 464 comple

batch 618 complete. Loss:  0.062379442155361176
batch 619 complete. Loss:  0.06872455775737762
batch 620 complete. Loss:  0.03420107439160347
batch 621 complete. Loss:  0.0518348291516304
batch 622 complete. Loss:  0.13477042317390442
batch 623 complete. Loss:  0.07752599567174911
batch 624 complete. Loss:  0.023676173761487007
batch 625 complete. Loss:  0.04217865318059921
batch 626 complete. Loss:  0.06516201794147491
batch 627 complete. Loss:  0.05595732107758522
batch 628 complete. Loss:  0.05365767329931259
batch 629 complete. Loss:  0.06531409919261932
batch 630 complete. Loss:  0.06334937363862991
batch 631 complete. Loss:  0.03801814094185829
batch 632 complete. Loss:  0.0459611713886261
batch 633 complete. Loss:  0.05839267373085022
batch 634 complete. Loss:  0.0388924665749073
batch 635 complete. Loss:  0.0715794488787651
batch 636 complete. Loss:  0.05163619667291641
batch 637 complete. Loss:  0.06274642050266266
batch 638 complete. Loss:  0.07509735226631165
batch 639 compl

batch 119 complete. Loss:  0.039374373853206635
batch 120 complete. Loss:  0.10401271283626556
batch 121 complete. Loss:  0.029490100219845772
batch 122 complete. Loss:  0.07692798972129822
batch 123 complete. Loss:  0.03886633366346359
batch 124 complete. Loss:  0.15150530636310577
batch 125 complete. Loss:  0.034418873488903046
batch 126 complete. Loss:  0.03936697542667389
batch 127 complete. Loss:  0.12684835493564606
batch 128 complete. Loss:  0.04495218023657799
batch 129 complete. Loss:  0.029726330190896988
batch 130 complete. Loss:  0.06244882941246033
batch 131 complete. Loss:  0.07510388642549515
batch 132 complete. Loss:  0.06827323138713837
batch 133 complete. Loss:  0.039934203028678894
batch 134 complete. Loss:  0.051679402589797974
batch 135 complete. Loss:  0.03035280480980873
batch 136 complete. Loss:  0.04681198298931122
batch 137 complete. Loss:  0.05629666522145271
batch 138 complete. Loss:  0.0648571252822876
batch 139 complete. Loss:  0.07910546660423279
batch 14

batch 293 complete. Loss:  0.03524428978562355
batch 294 complete. Loss:  0.03780083730816841
batch 295 complete. Loss:  0.07423228770494461
batch 296 complete. Loss:  0.09463278949260712
batch 297 complete. Loss:  0.05167313665151596
batch 298 complete. Loss:  0.039749257266521454
batch 299 complete. Loss:  0.0191054530441761
batch 300 complete. Loss:  0.06702262163162231
batch 301 complete. Loss:  0.01837606355547905
batch 302 complete. Loss:  0.032004162669181824
batch 303 complete. Loss:  0.10710892081260681
batch 304 complete. Loss:  0.04348495230078697
batch 305 complete. Loss:  0.06587731838226318
batch 306 complete. Loss:  0.06054847687482834
batch 307 complete. Loss:  0.04319251701235771
batch 308 complete. Loss:  0.05169905722141266
batch 309 complete. Loss:  0.03922798112034798
batch 310 complete. Loss:  0.0524991899728775
batch 311 complete. Loss:  0.020032204687595367
batch 312 complete. Loss:  0.04218675568699837
batch 313 complete. Loss:  0.04696369916200638
batch 314 co

batch 467 complete. Loss:  0.04394589364528656
batch 468 complete. Loss:  0.04972399026155472
batch 469 complete. Loss:  0.10322847962379456
batch 470 complete. Loss:  0.026997698470950127
batch 471 complete. Loss:  0.04129462316632271
batch 472 complete. Loss:  0.02679593861103058
batch 473 complete. Loss:  0.09016990661621094
batch 474 complete. Loss:  0.041429609060287476
batch 475 complete. Loss:  0.03235528618097305
batch 476 complete. Loss:  0.032820574939250946
batch 477 complete. Loss:  0.03747983276844025
batch 478 complete. Loss:  0.055509235709905624
batch 479 complete. Loss:  0.05586274713277817
batch 480 complete. Loss:  0.057237472385168076
batch 481 complete. Loss:  0.06047193706035614
batch 482 complete. Loss:  0.04445023834705353
batch 483 complete. Loss:  0.03946419805288315
batch 484 complete. Loss:  0.078096903860569
batch 485 complete. Loss:  0.069475457072258
batch 486 complete. Loss:  0.04150668531656265
batch 487 complete. Loss:  0.0905393585562706
batch 488 com

batch 641 complete. Loss:  0.021422434598207474
batch 642 complete. Loss:  0.04958850145339966
batch 643 complete. Loss:  0.044435400515794754
batch 644 complete. Loss:  0.031311407685279846
batch 645 complete. Loss:  0.03616630285978317
batch 646 complete. Loss:  0.06202509254217148
batch 647 complete. Loss:  0.13168080151081085
batch 648 complete. Loss:  0.019508518278598785
batch 649 complete. Loss:  0.03429427742958069
batch 650 complete. Loss:  0.08790234476327896
batch 651 complete. Loss:  0.05667068809270859
batch 652 complete. Loss:  0.06605293601751328
batch 653 complete. Loss:  0.04575273394584656
batch 654 complete. Loss:  0.06961572170257568
batch 655 complete. Loss:  0.0605425089597702
batch 656 complete. Loss:  0.05181507393717766
batch 657 complete. Loss:  0.07030832767486572
batch 658 complete. Loss:  0.07448258250951767
batch 659 complete. Loss:  0.027884159237146378
batch 660 complete. Loss:  0.08230668306350708
batch 661 complete. Loss:  0.07773493230342865
batch 662

batch 141 complete. Loss:  0.04474598541855812
batch 142 complete. Loss:  0.050043001770973206
batch 143 complete. Loss:  0.04332298785448074
batch 144 complete. Loss:  0.027146317064762115
batch 145 complete. Loss:  0.022363625466823578
batch 146 complete. Loss:  0.03963640332221985
batch 147 complete. Loss:  0.07060390710830688
batch 148 complete. Loss:  0.0704805850982666
batch 149 complete. Loss:  0.030146287754178047
batch 150 complete. Loss:  0.048348285257816315
batch 151 complete. Loss:  0.06310677528381348
batch 152 complete. Loss:  0.062303632497787476
batch 153 complete. Loss:  0.039327025413513184
batch 154 complete. Loss:  0.02929057739675045
batch 155 complete. Loss:  0.04545127972960472
batch 156 complete. Loss:  0.07075761258602142
batch 157 complete. Loss:  0.04020293056964874
batch 158 complete. Loss:  0.040239013731479645
batch 159 complete. Loss:  0.030804378911852837
batch 160 complete. Loss:  0.03204571455717087
batch 161 complete. Loss:  0.03983473777770996
batch

batch 315 complete. Loss:  0.025449376553297043
batch 316 complete. Loss:  0.015898019075393677
batch 317 complete. Loss:  0.01739831641316414
batch 318 complete. Loss:  0.02558708004653454
batch 319 complete. Loss:  0.032590966671705246
batch 320 complete. Loss:  0.04124421626329422
batch 321 complete. Loss:  0.029522273689508438
batch 322 complete. Loss:  0.04291978105902672
batch 323 complete. Loss:  0.03075643628835678
batch 324 complete. Loss:  0.015020238235592842
batch 325 complete. Loss:  0.03285729140043259
batch 326 complete. Loss:  0.031211351975798607
batch 327 complete. Loss:  0.04847639799118042
batch 328 complete. Loss:  0.05358142778277397
batch 329 complete. Loss:  0.047288455069065094
batch 330 complete. Loss:  0.1025332435965538
batch 331 complete. Loss:  0.022304989397525787
batch 332 complete. Loss:  0.04075721651315689
batch 333 complete. Loss:  0.03875170275568962
batch 334 complete. Loss:  0.041719451546669006
batch 335 complete. Loss:  0.02008046768605709
batch

batch 489 complete. Loss:  0.04083135724067688
batch 490 complete. Loss:  0.032942354679107666
batch 491 complete. Loss:  0.03179004043340683
batch 492 complete. Loss:  0.03527680039405823
batch 493 complete. Loss:  0.016565388068556786
batch 494 complete. Loss:  0.08472692966461182
batch 495 complete. Loss:  0.05199740082025528
batch 496 complete. Loss:  0.02737889066338539
batch 497 complete. Loss:  0.04210679605603218
batch 498 complete. Loss:  0.03981637954711914
batch 499 complete. Loss:  0.03178190439939499
batch 500 complete. Loss:  0.045843180269002914
batch 501 complete. Loss:  0.029750235378742218
batch 502 complete. Loss:  0.05774549022316933
batch 503 complete. Loss:  0.05487268045544624
batch 504 complete. Loss:  0.02558821626007557
batch 505 complete. Loss:  0.04181181639432907
batch 506 complete. Loss:  0.02959371730685234
batch 507 complete. Loss:  0.03150492161512375
batch 508 complete. Loss:  0.028099972754716873
batch 509 complete. Loss:  0.03048587404191494
batch 51

batch 663 complete. Loss:  0.025568678975105286
batch 664 complete. Loss:  0.029359476640820503
batch 665 complete. Loss:  0.029953472316265106
batch 666 complete. Loss:  0.04386395215988159
batch 667 complete. Loss:  0.02356957271695137
batch 668 complete. Loss:  0.02941734530031681
batch 669 complete. Loss:  0.022707758471369743
batch 670 complete. Loss:  0.012644760310649872
batch 671 complete. Loss:  0.043548643589019775
batch 672 complete. Loss:  0.05798732861876488
batch 673 complete. Loss:  0.0790824219584465
batch 674 complete. Loss:  0.029748156666755676
Epoch 5 Validation Loss: 0.08509285335798235
6 / 10
batch 0 complete. Loss:  0.03071102313697338
batch 1 complete. Loss:  0.022545769810676575
batch 2 complete. Loss:  0.03320375829935074
batch 3 complete. Loss:  0.0257123876363039
batch 4 complete. Loss:  0.01833181641995907
batch 5 complete. Loss:  0.02266438864171505
batch 6 complete. Loss:  0.044694580137729645
batch 7 complete. Loss:  0.06725070625543594
batch 8 complete.

batch 163 complete. Loss:  0.030005620792508125
batch 164 complete. Loss:  0.04189874976873398
batch 165 complete. Loss:  0.033970557153224945
batch 166 complete. Loss:  0.023955892771482468
batch 167 complete. Loss:  0.026502255350351334
batch 168 complete. Loss:  0.015613672323524952
batch 169 complete. Loss:  0.03009020909667015
batch 170 complete. Loss:  0.02104872837662697
batch 171 complete. Loss:  0.021998893469572067
batch 172 complete. Loss:  0.019267655909061432
batch 173 complete. Loss:  0.028887202963232994
batch 174 complete. Loss:  0.01940854638814926
batch 175 complete. Loss:  0.025907661765813828
batch 176 complete. Loss:  0.013574306853115559
batch 177 complete. Loss:  0.0335233099758625
batch 178 complete. Loss:  0.049781158566474915
batch 179 complete. Loss:  0.027938958257436752
batch 180 complete. Loss:  0.031212851405143738
batch 181 complete. Loss:  0.08756347000598907
batch 182 complete. Loss:  0.018275177106261253
batch 183 complete. Loss:  0.02915719524025917


batch 336 complete. Loss:  0.03241429477930069
batch 337 complete. Loss:  0.017600033432245255
batch 338 complete. Loss:  0.019046800211071968
batch 339 complete. Loss:  0.01765976846218109
batch 340 complete. Loss:  0.01754995435476303
batch 341 complete. Loss:  0.014619926922023296
batch 342 complete. Loss:  0.01773984171450138
batch 343 complete. Loss:  0.022684112191200256
batch 344 complete. Loss:  0.017853688448667526
batch 345 complete. Loss:  0.023837421089410782
batch 346 complete. Loss:  0.012018844485282898
batch 347 complete. Loss:  0.035701632499694824
batch 348 complete. Loss:  0.019564572721719742
batch 349 complete. Loss:  0.023388739675283432
batch 350 complete. Loss:  0.023948565125465393
batch 351 complete. Loss:  0.028652429580688477
batch 352 complete. Loss:  0.02480814978480339
batch 353 complete. Loss:  0.02415318414568901
batch 354 complete. Loss:  0.033478621393442154
batch 355 complete. Loss:  0.032367344945669174
batch 356 complete. Loss:  0.02162762172520160

batch 509 complete. Loss:  0.013097809627652168
batch 510 complete. Loss:  0.024111472070217133
batch 511 complete. Loss:  0.03338801488280296
batch 512 complete. Loss:  0.02564728818833828
batch 513 complete. Loss:  0.02074911817908287
batch 514 complete. Loss:  0.02152187190949917
batch 515 complete. Loss:  0.025328129529953003
batch 516 complete. Loss:  0.041378844529390335
batch 517 complete. Loss:  0.10320569574832916
batch 518 complete. Loss:  0.05963990464806557
batch 519 complete. Loss:  0.025724051520228386
batch 520 complete. Loss:  0.029043545946478844
batch 521 complete. Loss:  0.09141027182340622
batch 522 complete. Loss:  0.03808911144733429
batch 523 complete. Loss:  0.02658633515238762
batch 524 complete. Loss:  0.09278438985347748
batch 525 complete. Loss:  0.05691683664917946
batch 526 complete. Loss:  0.021270066499710083
batch 527 complete. Loss:  0.08265508711338043
batch 528 complete. Loss:  0.06713160872459412
batch 529 complete. Loss:  0.03727264702320099
batch 

batch 6 complete. Loss:  0.01774699240922928
batch 7 complete. Loss:  0.011570021510124207
batch 8 complete. Loss:  0.020867502316832542
batch 9 complete. Loss:  0.033928703516721725
batch 10 complete. Loss:  0.03115016594529152
batch 11 complete. Loss:  0.03570173680782318
batch 12 complete. Loss:  0.021074095740914345
batch 13 complete. Loss:  0.02220735512673855
batch 14 complete. Loss:  0.011634377762675285
batch 15 complete. Loss:  0.06588225066661835
batch 16 complete. Loss:  0.015466781333088875
batch 17 complete. Loss:  0.022801604121923447
batch 18 complete. Loss:  0.02239997871220112
batch 19 complete. Loss:  0.02144012413918972
batch 20 complete. Loss:  0.016528476029634476
batch 21 complete. Loss:  0.028544608503580093
batch 22 complete. Loss:  0.013999473303556442
batch 23 complete. Loss:  0.017582479864358902
batch 24 complete. Loss:  0.034766390919685364
batch 25 complete. Loss:  0.021987762302160263
batch 26 complete. Loss:  0.026656359434127808
batch 27 complete. Loss:

batch 181 complete. Loss:  0.019195500761270523
batch 182 complete. Loss:  0.08910372853279114
batch 183 complete. Loss:  0.07157435268163681
batch 184 complete. Loss:  0.022264841943979263
batch 185 complete. Loss:  0.01663903519511223
batch 186 complete. Loss:  0.013817119412124157
batch 187 complete. Loss:  0.030764950439333916
batch 188 complete. Loss:  0.024432379752397537
batch 189 complete. Loss:  0.029178794473409653
batch 190 complete. Loss:  0.016646448522806168
batch 191 complete. Loss:  0.020639777183532715
batch 192 complete. Loss:  0.04555830359458923
batch 193 complete. Loss:  0.022617504000663757
batch 194 complete. Loss:  0.01550272572785616
batch 195 complete. Loss:  0.015977229923009872
batch 196 complete. Loss:  0.019856223836541176
batch 197 complete. Loss:  0.028174590319395065
batch 198 complete. Loss:  0.027691910043358803
batch 199 complete. Loss:  0.022971797734498978
batch 200 complete. Loss:  0.01584327034652233
batch 201 complete. Loss:  0.01940581575036049

batch 354 complete. Loss:  0.032782625406980515
batch 355 complete. Loss:  0.02832995355129242
batch 356 complete. Loss:  0.024798259139060974
batch 357 complete. Loss:  0.01873948611319065
batch 358 complete. Loss:  0.026042625308036804
batch 359 complete. Loss:  0.028648819774389267
batch 360 complete. Loss:  0.0351032018661499
batch 361 complete. Loss:  0.07504221796989441
batch 362 complete. Loss:  0.053852636367082596
batch 363 complete. Loss:  0.012038737535476685
batch 364 complete. Loss:  0.010454870760440826
batch 365 complete. Loss:  0.026806462556123734
batch 366 complete. Loss:  0.013660338707268238
batch 367 complete. Loss:  0.031844042241573334
batch 368 complete. Loss:  0.012471869587898254
batch 369 complete. Loss:  0.025263674557209015
batch 370 complete. Loss:  0.019406002014875412
batch 371 complete. Loss:  0.05038627237081528
batch 372 complete. Loss:  0.08040334284305573
batch 373 complete. Loss:  0.026737118139863014
batch 374 complete. Loss:  0.019943254068493843

batch 527 complete. Loss:  0.08279125392436981
batch 528 complete. Loss:  0.03137415274977684
batch 529 complete. Loss:  0.025958213955163956
batch 530 complete. Loss:  0.03377465531229973
batch 531 complete. Loss:  0.04333442449569702
batch 532 complete. Loss:  0.03143757954239845
batch 533 complete. Loss:  0.02738676592707634
batch 534 complete. Loss:  0.0156564824283123
batch 535 complete. Loss:  0.023387577384710312
batch 536 complete. Loss:  0.028023039922118187
batch 537 complete. Loss:  0.029802465811371803
batch 538 complete. Loss:  0.027432376518845558
batch 539 complete. Loss:  0.015324774198234081
batch 540 complete. Loss:  0.0255103912204504
batch 541 complete. Loss:  0.016647953540086746
batch 542 complete. Loss:  0.056784145534038544
batch 543 complete. Loss:  0.0254665520042181
batch 544 complete. Loss:  0.02913171797990799
batch 545 complete. Loss:  0.041065189987421036
batch 546 complete. Loss:  0.014018239453434944
batch 547 complete. Loss:  0.015401182696223259
batch

batch 25 complete. Loss:  0.04675748944282532
batch 26 complete. Loss:  0.020644711330533028
batch 27 complete. Loss:  0.018635833635926247
batch 28 complete. Loss:  0.02501242235302925
batch 29 complete. Loss:  0.01414906419813633
batch 30 complete. Loss:  0.017175914719700813
batch 31 complete. Loss:  0.012842555530369282
batch 32 complete. Loss:  0.02498454600572586
batch 33 complete. Loss:  0.01914353482425213
batch 34 complete. Loss:  0.026240380480885506
batch 35 complete. Loss:  0.010543476790189743
batch 36 complete. Loss:  0.02025396004319191
batch 37 complete. Loss:  0.021527495235204697
batch 38 complete. Loss:  0.010715119540691376
batch 39 complete. Loss:  0.015984883531928062
batch 40 complete. Loss:  0.015766337513923645
batch 41 complete. Loss:  0.019159339368343353
batch 42 complete. Loss:  0.014187546446919441
batch 43 complete. Loss:  0.02279421128332615
batch 44 complete. Loss:  0.020698312669992447
batch 45 complete. Loss:  0.02029728703200817
batch 46 complete. Lo

KeyboardInterrupt: 

In [21]:
idx =101

#text=scaled_data['text'][idx]
text='''East Lansing Public Schools, home of Michigan State University (MSU) and a neighbor to the state capital, seeks a
dynamic individual to work in our highly diverse school district. This individual must love kids, have a passion for
their success, be a collaborative team player, deeply understand best practices for teaching and learning, and
engage all students in the learning process. Work in a district that has built six new elementary schools with stateof-the-art instructional and sustainable enhancements, has multiple partnerships with MSU, encourages
continuous professional growth and learning, and is highly supported by its families and community!

POSITION: Summer School Administrator (shared position)
LOCATION: Donley and Red Cedar Elementary Schools
SALARY: Stipend: $3,500
WORK YEAR: Weeks of June 24, July 8, July 15, Plus July 22 and July 25 - Work Days: M, T, W, TH 8:30 AM to
1:30 PM
START DATE: June 24, 2024
REPORTS TO: Assistant Superintendent
APPLICATION DEADLINE: March 15, 2024


GENERAL SUMMARY: Under the direct supervision of the Assistant Superintendent, and in accordance with the
established policies and procedures of the East Lansing Public School District, the Summer School Administrator is
responsible for the administration and oversight of the elementary summer school program. This person will be
sharing the administration with an elementary principal and working the weeks/days the elementary principal is not
working. The elementary summer school program will likely be run at two sites: Donley and Red Cedar, so travel
between sites will be necessary.'''
data = torch.tensor(nn_data[idx]).unsqueeze(0)

# inference
inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True).to('cuda')
outputs = model(**inputs, nn_data=data)
pred = scaler.inverse_transform(np.array([outputs.item()]))
print(pred)
#print(pred, scaler.inverse_transform(np.array([scaled_data['label'][idx]])))



[42675.0019354]


In [22]:
train_losses, val_losses

([0.40143253423549513,
  0.16957210713514576,
  0.1047804256418237,
  0.07134516571406965,
  0.05439718822362246,
  0.041425507995817394,
  0.034718237422682624,
  0.029447722020386546],
 [0.2210999514312434,
  0.15841861827486364,
  0.11680733535945768,
  0.105252427616592,
  0.10249069069660979,
  0.08509285335798235,
  0.07618940602877788,
  0.08059447902370487])

In [23]:
train_losses = []
with torch.no_grad():  # Disable gradient calculation for validation
    for batch in train_dataloader:
        input_ids = batch['input_ids'].to('cuda') 
        attention_mask = batch['attention_mask'].to('cuda')
        labels = batch['labels'].to('cuda')
        data = batch['nn_data'].to('cuda')
        for i in range(len(labels)):
            output = model(input_ids[i].unsqueeze(0), 
                           attention_mask=attention_mask[i].unsqueeze(0),
                           nn_data = data[i].unsqueeze(0))
            yhat = scaler.inverse_transform(np.array([output.squeeze(1).item()]))
            y=scaler.inverse_transform(np.array([labels[i].unsqueeze(0).item()]))
            loss = np.abs(yhat-y)
            
            train_losses.append(loss)
            #print(loss)


In [24]:
test_losses = []
with torch.no_grad():  # Disable gradient calculation for validation
    for batch in val_dataloader:
        input_ids = batch['input_ids'].to('cuda') 
        attention_mask = batch['attention_mask'].to('cuda')
        labels = batch['labels'].to('cuda')
        data = batch['nn_data'].to('cuda')
        for i in range(len(labels)):
            output = model(input_ids[i].unsqueeze(0), 
                           attention_mask=attention_mask[i].unsqueeze(0),
                           nn_data = data[i].unsqueeze(0))
            yhat = scaler.inverse_transform(np.array([output.squeeze(1).item()]))
            y=scaler.inverse_transform(np.array([labels[i].unsqueeze(0).item()]))
            loss = np.abs(yhat-y)
            
            test_losses.append(loss)
            
        


In [26]:
print("Mean training err: $" +  str(np.mean(train_losses)))
print("Median training err: $" + str(np.median(train_losses)))
print("Mean test err: $" + str(np.mean(test_losses)))
print("Median test err: $" + str(np.median(test_losses)))

Mean training err: $6472.70942511934
Median training err: $4943.678810362908
Mean test err: $9306.064277280737
Median test err: $5682.710841949112


In [None]:
#del model
#torch.cuda.empty_cache()