In [1]:
import random

import torch
from torch.utils.data import DataLoader
from transformers import BertTokenizer

from data.dataset import CustomDataset

# Configuration: seed the RNGs the notebook relies on so that
# Restart Kernel -> Run All is reproducible. `random` is used for the sample
# shown below; torch's global generator drives DataLoader shuffling (no explicit
# generator is passed) and any torch-based parameter initialisation.
SEED = 42
BATCH_SIZE = 32
random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Tokenized dataset, reshuffled every epoch.
data_loader = DataLoader(CustomDataset(tokenize=True), batch_size=BATCH_SIZE, shuffle=True)

In [2]:
# Report how many examples the data loader will iterate over.
num_examples = len(data_loader.dataset)
print(f"number of examples are: {num_examples}")

number of examples are: 4012


In [3]:
import random

# Display one randomly chosen question/answer pair as a quick sanity check.
# The dataset is instantiated here without `tokenize=True`, so the fields are
# expected to hold readable text (TODO confirm CustomDataset's default).
Q_A_data = CustomDataset().dataset
random_idx = random.randrange(len(Q_A_data))
example = Q_A_data[random_idx]
print(f"Question\n {example['question']}")
print(f"Answer\n {example['answer']}")

Question
 What percentage of Homo sapiens DNA is of Neanderthal origin?
Answer
 We find that the power to reject ancient admixture might be particularly low if the population size of Homo sapiens was comparable to the Neanderthal population size. Our results indicate that 3.6% of the Neanderthal genome is shared with roughly 65.4% of the average European gene pool, which clinally diminishes with distance from Europe.


## Training

In [4]:
from scripts.model import DualEncoder

# Hyper-parameters for the dual encoder and the training run.
embed_dim = 512   # embedding dimension handed to DualEncoder
output_dim = 128  # output dimension handed to DualEncoder
epochs = 10       # number of training epochs

# Build the model around the pretrained BERT tokenizer, then move it to the device.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = DualEncoder(tokenizer, embed_dim, output_dim)
model = model.to(device)

2024-09-11 19:40:56,215 - root - INFO - Dual encoder model for Q&A successfully initialized with:

 embedding dimension:512 
 output_dim:128 


In [6]:
from scripts.train import train
import time

# Time the run with a monotonic clock. The `finally` block still reports the
# elapsed time if training is stopped early (e.g. KeyboardInterrupt); the
# exception itself keeps propagating.
start_time = time.perf_counter()
try:
    # Use the `epochs` value configured alongside the model instead of a second
    # hard-coded literal, so the two settings cannot drift apart.
    train(model=model, device=device, epochs=epochs, data_loader=data_loader)
finally:
    end_time = time.perf_counter()
    print(f"Time taken for training: {end_time-start_time}")

2024-09-11 19:42:44,062 - root - INFO - Training begins for 10 epochs
2024-09-11 19:42:44,062 - root - INFO - Epoch 1/10


  0%|          | 0/126 [00:00<?, ?it/s]

2024-09-11 19:43:00,941 - my_module - INFO - Epoch 1, loss = 3.8397176265716553


  8%|▊         | 10/126 [02:11<25:18, 13.09s/it]

2024-09-11 19:45:06,979 - my_module - INFO - Epoch 1, loss = 3.7976727052168413


 16%|█▌        | 20/126 [05:41<33:26, 18.93s/it]

2024-09-11 19:48:39,716 - my_module - INFO - Epoch 1, loss = 3.7733155886332193


 24%|██▍       | 30/126 [07:45<19:55, 12.45s/it]

2024-09-11 19:50:42,093 - my_module - INFO - Epoch 1, loss = 3.7299134962020384


 32%|███▏      | 40/126 [10:03<20:35, 14.37s/it]

2024-09-11 19:53:00,863 - my_module - INFO - Epoch 1, loss = 3.6948631216840044


 40%|███▉      | 50/126 [12:07<14:53, 11.75s/it]

2024-09-11 19:55:02,822 - my_module - INFO - Epoch 1, loss = 3.6677115758260093


 48%|████▊     | 60/126 [14:01<12:33, 11.42s/it]

2024-09-11 19:56:57,587 - my_module - INFO - Epoch 1, loss = 3.6460029531697757


 56%|█████▌    | 70/126 [15:56<10:57, 11.74s/it]

2024-09-11 19:58:52,741 - my_module - INFO - Epoch 1, loss = 3.6304362021701437


 63%|██████▎   | 80/126 [17:56<09:06, 11.89s/it]

2024-09-11 20:00:51,913 - my_module - INFO - Epoch 1, loss = 3.6156619625327027


 71%|███████▏  | 90/126 [19:53<06:59, 11.65s/it]

2024-09-11 20:02:49,273 - my_module - INFO - Epoch 1, loss = 3.6050542370303646


 79%|███████▉  | 100/126 [21:49<05:02, 11.65s/it]

2024-09-11 20:04:45,243 - my_module - INFO - Epoch 1, loss = 3.5904733143230474


 87%|████████▋ | 110/126 [23:45<03:05, 11.62s/it]

2024-09-11 20:06:41,397 - my_module - INFO - Epoch 1, loss = 3.5787845452626548


 95%|█████████▌| 120/126 [25:39<01:07, 11.29s/it]

2024-09-11 20:08:35,378 - my_module - INFO - Epoch 1, loss = 3.570331437528626


100%|██████████| 126/126 [26:41<00:00, 12.71s/it]

2024-09-11 20:09:25,158 - root - INFO - Epoch 2/10



  0%|          | 0/126 [00:00<?, ?it/s]

2024-09-11 20:09:36,352 - my_module - INFO - Epoch 2, loss = 3.4430994987487793


  8%|▊         | 10/126 [01:54<23:01, 11.91s/it]

2024-09-11 20:11:32,134 - my_module - INFO - Epoch 2, loss = 3.420604012229226


 16%|█▌        | 20/126 [03:57<21:33, 12.21s/it]

2024-09-11 20:13:34,485 - my_module - INFO - Epoch 2, loss = 3.410097564969744


 24%|██▍       | 30/126 [06:00<19:38, 12.28s/it]

2024-09-11 20:15:37,644 - my_module - INFO - Epoch 2, loss = 3.4114007488373788


 32%|███▏      | 40/126 [08:02<17:30, 12.22s/it]

2024-09-11 20:17:39,613 - my_module - INFO - Epoch 2, loss = 3.4034555365399615


 40%|███▉      | 50/126 [10:04<15:25, 12.18s/it]

2024-09-11 20:19:41,378 - my_module - INFO - Epoch 2, loss = 3.401236978231692


 48%|████▊     | 60/126 [12:06<13:29, 12.26s/it]

2024-09-11 20:21:43,948 - my_module - INFO - Epoch 2, loss = 3.3944674984353487


 56%|█████▌    | 70/126 [14:09<11:24, 12.23s/it]

2024-09-11 20:23:46,707 - my_module - INFO - Epoch 2, loss = 3.390174543353873


 63%|██████▎   | 80/126 [16:10<09:19, 12.17s/it]

2024-09-11 20:25:47,464 - my_module - INFO - Epoch 2, loss = 3.385257476641808


 71%|███████▏  | 90/126 [18:12<07:20, 12.23s/it]

2024-09-11 20:27:50,253 - my_module - INFO - Epoch 2, loss = 3.3804932557619534


 79%|███████▉  | 100/126 [20:15<05:17, 12.21s/it]

2024-09-11 20:29:52,604 - my_module - INFO - Epoch 2, loss = 3.3777719511844144


 87%|████████▋ | 110/126 [22:17<03:15, 12.21s/it]

2024-09-11 20:31:54,765 - my_module - INFO - Epoch 2, loss = 3.3720040385787553


 95%|█████████▌| 120/126 [24:18<01:12, 12.09s/it]

2024-09-11 20:33:55,725 - my_module - INFO - Epoch 2, loss = 3.3669453100724653


100%|██████████| 126/126 [25:20<00:00, 12.07s/it]

2024-09-11 20:34:46,125 - root - INFO - Epoch 3/10



  0%|          | 0/126 [00:00<?, ?it/s]

2024-09-11 20:35:00,708 - my_module - INFO - Epoch 3, loss = 3.3369243144989014


  8%|▊         | 10/126 [02:04<23:40, 12.25s/it]

2024-09-11 20:37:03,086 - my_module - INFO - Epoch 3, loss = 3.244333180514249


 16%|█▌        | 20/126 [04:06<21:32, 12.19s/it]

2024-09-11 20:39:05,836 - my_module - INFO - Epoch 3, loss = 3.235588244029454


 24%|██▍       | 30/126 [06:10<19:37, 12.26s/it]

2024-09-11 20:41:08,402 - my_module - INFO - Epoch 3, loss = 3.238188882027903


 32%|███▏      | 40/126 [08:16<18:09, 12.66s/it]

2024-09-11 20:43:14,783 - my_module - INFO - Epoch 3, loss = 3.240501671302609


 40%|███▉      | 50/126 [10:19<15:37, 12.34s/it]

2024-09-11 20:45:17,741 - my_module - INFO - Epoch 3, loss = 3.2380758266822967


 48%|████▊     | 60/126 [12:22<13:36, 12.37s/it]

2024-09-11 20:47:21,293 - my_module - INFO - Epoch 3, loss = 3.2295176044839327


 56%|█████▌    | 70/126 [14:24<11:19, 12.14s/it]

2024-09-11 20:49:23,261 - my_module - INFO - Epoch 3, loss = 3.218511685519151


 63%|██████▎   | 80/126 [16:27<09:26, 12.32s/it]

2024-09-11 20:51:25,832 - my_module - INFO - Epoch 3, loss = 3.207914196414712


 71%|███████▏  | 90/126 [18:35<07:40, 12.78s/it]

2024-09-11 20:53:33,773 - my_module - INFO - Epoch 3, loss = 3.1964727889050493


 79%|███████▉  | 100/126 [20:38<05:18, 12.25s/it]

2024-09-11 20:55:36,348 - my_module - INFO - Epoch 3, loss = 3.1905959596728333


 87%|████████▋ | 110/126 [22:39<03:12, 12.06s/it]

2024-09-11 20:57:37,319 - my_module - INFO - Epoch 3, loss = 3.188027285240792


 95%|█████████▌| 120/126 [24:44<01:18, 13.03s/it]

2024-09-11 20:59:42,662 - my_module - INFO - Epoch 3, loss = 3.1824624183749366


100%|██████████| 126/126 [25:50<00:00, 12.31s/it]

2024-09-11 21:00:36,854 - root - INFO - Epoch 4/10



  0%|          | 0/126 [00:00<?, ?it/s]

2024-09-11 21:00:49,053 - my_module - INFO - Epoch 4, loss = 3.144430160522461


  8%|▊         | 10/126 [02:01<23:28, 12.14s/it]

2024-09-11 21:02:50,238 - my_module - INFO - Epoch 4, loss = 2.964867938648571


 16%|█▌        | 20/126 [04:02<21:23, 12.11s/it]

2024-09-11 21:04:51,211 - my_module - INFO - Epoch 4, loss = 2.953532797949655


 24%|██▍       | 30/126 [06:05<20:21, 12.72s/it]

2024-09-11 21:06:54,767 - my_module - INFO - Epoch 4, loss = 2.9628697595288678


 32%|███▏      | 40/126 [08:08<17:34, 12.26s/it]

2024-09-11 21:08:57,129 - my_module - INFO - Epoch 4, loss = 2.9555604225251733


 40%|███▉      | 50/126 [10:09<15:24, 12.16s/it]

2024-09-11 21:10:58,472 - my_module - INFO - Epoch 4, loss = 2.9344350450179157


 48%|████▊     | 60/126 [12:12<13:33, 12.32s/it]

2024-09-11 21:13:01,850 - my_module - INFO - Epoch 4, loss = 2.9265919435219687


 56%|█████▌    | 70/126 [14:14<11:22, 12.19s/it]

2024-09-11 21:15:03,223 - my_module - INFO - Epoch 4, loss = 2.9042077366734893


 63%|██████▎   | 80/126 [16:17<09:34, 12.50s/it]

2024-09-11 21:17:06,979 - my_module - INFO - Epoch 4, loss = 2.892471757935889


 71%|███████▏  | 90/126 [18:22<07:26, 12.39s/it]

2024-09-11 21:19:11,147 - my_module - INFO - Epoch 4, loss = 2.8813416014660844


 79%|███████▉  | 100/126 [20:24<05:20, 12.31s/it]

2024-09-11 21:21:13,892 - my_module - INFO - Epoch 4, loss = 2.870108710657252


 87%|████████▋ | 110/126 [22:28<03:17, 12.32s/it]

2024-09-11 21:23:17,065 - my_module - INFO - Epoch 4, loss = 2.8609658576346733


 95%|█████████▌| 120/126 [24:31<01:13, 12.32s/it]

2024-09-11 21:25:20,243 - my_module - INFO - Epoch 4, loss = 2.8495219325231127


100%|██████████| 126/126 [25:37<00:00, 12.20s/it]

2024-09-11 21:26:14,628 - root - INFO - Epoch 5/10



  0%|          | 0/126 [00:00<?, ?it/s]

2024-09-11 21:26:27,021 - my_module - INFO - Epoch 5, loss = 2.969083309173584


  8%|▊         | 10/126 [02:03<23:41, 12.26s/it]

2024-09-11 21:28:30,184 - my_module - INFO - Epoch 5, loss = 2.6882651068947534


 16%|█▌        | 20/126 [04:03<20:35, 11.65s/it]

2024-09-11 21:30:28,948 - my_module - INFO - Epoch 5, loss = 2.597092344647362


 24%|██▍       | 30/126 [05:55<17:58, 11.23s/it]

2024-09-11 21:32:21,321 - my_module - INFO - Epoch 5, loss = 2.6123102480365383


 32%|███▏      | 40/126 [07:48<16:06, 11.24s/it]

2024-09-11 21:34:13,858 - my_module - INFO - Epoch 5, loss = 2.5808562069404415


 40%|███▉      | 50/126 [09:41<14:17, 11.28s/it]

2024-09-11 21:36:06,845 - my_module - INFO - Epoch 5, loss = 2.5746518957848643


 48%|████▊     | 60/126 [11:33<12:26, 11.31s/it]

2024-09-11 21:37:59,813 - my_module - INFO - Epoch 5, loss = 2.5732937289066


 56%|█████▌    | 70/126 [13:26<10:28, 11.23s/it]

2024-09-11 21:39:51,983 - my_module - INFO - Epoch 5, loss = 2.5642923670755304


 63%|██████▎   | 80/126 [15:18<08:42, 11.35s/it]

2024-09-11 21:41:44,744 - my_module - INFO - Epoch 5, loss = 2.5404153517734858


 71%|███████▏  | 90/126 [17:11<06:44, 11.24s/it]

2024-09-11 21:43:36,918 - my_module - INFO - Epoch 5, loss = 2.532392850288978


 79%|███████▉  | 100/126 [19:03<04:54, 11.33s/it]

2024-09-11 21:45:29,495 - my_module - INFO - Epoch 5, loss = 2.5349235817937568


 87%|████████▋ | 110/126 [20:56<03:01, 11.35s/it]

2024-09-11 21:47:22,445 - my_module - INFO - Epoch 5, loss = 2.5244047770629057


 95%|█████████▌| 120/126 [22:48<01:07, 11.20s/it]

2024-09-11 21:49:14,608 - my_module - INFO - Epoch 5, loss = 2.5220438212402594


100%|██████████| 126/126 [23:49<00:00, 11.34s/it]

2024-09-11 21:50:03,798 - root - INFO - Epoch 6/10



  0%|          | 0/126 [00:00<?, ?it/s]

2024-09-11 21:50:14,790 - my_module - INFO - Epoch 6, loss = 2.3685142993927


  8%|▊         | 10/126 [01:50<21:23, 11.07s/it]

2024-09-11 21:52:05,558 - my_module - INFO - Epoch 6, loss = 2.2543487548828125


 16%|█▌        | 20/126 [03:41<19:29, 11.03s/it]

2024-09-11 21:53:55,941 - my_module - INFO - Epoch 6, loss = 2.2156120595477877


 24%|██▍       | 30/126 [05:33<17:59, 11.25s/it]

2024-09-11 21:55:48,301 - my_module - INFO - Epoch 6, loss = 2.2269407433848225


 32%|███▏      | 40/126 [07:25<16:11, 11.30s/it]

2024-09-11 21:57:42,478 - my_module - INFO - Epoch 6, loss = 2.228595189931916


 40%|███▉      | 50/126 [09:32<15:41, 12.39s/it]

2024-09-11 21:59:49,030 - my_module - INFO - Epoch 6, loss = 2.244016605265


 48%|████▊     | 60/126 [11:35<13:27, 12.23s/it]

2024-09-11 22:01:51,396 - my_module - INFO - Epoch 6, loss = 2.2414921463512982


 56%|█████▌    | 70/126 [13:37<11:14, 12.05s/it]

2024-09-11 22:03:52,166 - my_module - INFO - Epoch 6, loss = 2.2526614565244865


 63%|██████▎   | 80/126 [15:29<08:40, 11.32s/it]

2024-09-11 22:05:44,332 - my_module - INFO - Epoch 6, loss = 2.248474519929768


 71%|███████▏  | 90/126 [17:20<06:41, 11.15s/it]

2024-09-11 22:07:35,909 - my_module - INFO - Epoch 6, loss = 2.246065351989243


 79%|███████▉  | 100/126 [19:14<04:56, 11.41s/it]

2024-09-11 22:09:30,062 - my_module - INFO - Epoch 6, loss = 2.2379756469537715


 87%|████████▋ | 110/126 [21:09<03:02, 11.43s/it]

2024-09-11 22:11:24,647 - my_module - INFO - Epoch 6, loss = 2.2334581720936404


 95%|█████████▌| 120/126 [23:03<01:08, 11.41s/it]

2024-09-11 22:13:19,187 - my_module - INFO - Epoch 6, loss = 2.2297833822975472


100%|██████████| 126/126 [24:05<00:00, 11.47s/it]

2024-09-11 22:14:09,172 - root - INFO - Epoch 7/10



  0%|          | 0/126 [00:00<?, ?it/s]

2024-09-11 22:14:20,579 - my_module - INFO - Epoch 7, loss = 2.3691906929016113


  8%|▊         | 10/126 [01:54<22:04, 11.42s/it]

2024-09-11 22:16:15,141 - my_module - INFO - Epoch 7, loss = 2.0878579183058306


 16%|█▌        | 20/126 [03:48<20:09, 11.41s/it]

2024-09-11 22:18:09,512 - my_module - INFO - Epoch 7, loss = 2.0245331014905656


 24%|██▍       | 30/126 [05:49<19:36, 12.26s/it]

2024-09-11 22:20:09,668 - my_module - INFO - Epoch 7, loss = 1.9935251089834398


 32%|███▏      | 40/126 [07:41<16:06, 11.23s/it]

2024-09-11 22:22:01,837 - my_module - INFO - Epoch 7, loss = 2.0046250471254674


 40%|███▉      | 50/126 [09:34<14:38, 11.55s/it]

2024-09-11 22:23:55,628 - my_module - INFO - Epoch 7, loss = 1.9946713050206502


 48%|████▊     | 60/126 [11:33<13:21, 12.14s/it]

2024-09-11 22:25:53,752 - my_module - INFO - Epoch 7, loss = 1.9937135919195708


 56%|█████▌    | 70/126 [13:23<10:19, 11.06s/it]

2024-09-11 22:27:44,092 - my_module - INFO - Epoch 7, loss = 1.9855892926874295


 63%|██████▎   | 80/126 [15:16<08:37, 11.26s/it]

2024-09-11 22:29:36,858 - my_module - INFO - Epoch 7, loss = 1.981680777337816


 71%|███████▏  | 90/126 [17:08<06:43, 11.22s/it]

2024-09-11 22:31:28,439 - my_module - INFO - Epoch 7, loss = 1.9944013841859587


 79%|███████▉  | 100/126 [18:59<04:51, 11.22s/it]

2024-09-11 22:33:19,805 - my_module - INFO - Epoch 7, loss = 2.0047611387649384


 87%|████████▋ | 110/126 [20:50<02:58, 11.17s/it]

2024-09-11 22:35:11,362 - my_module - INFO - Epoch 7, loss = 1.989807876380714


 95%|█████████▌| 120/126 [22:41<01:06, 11.07s/it]

2024-09-11 22:37:02,328 - my_module - INFO - Epoch 7, loss = 1.9884237996802843


100%|██████████| 126/126 [23:42<00:00, 11.29s/it]

2024-09-11 22:37:51,748 - root - INFO - Epoch 8/10



  0%|          | 0/126 [00:00<?, ?it/s]

2024-09-11 22:38:03,114 - my_module - INFO - Epoch 8, loss = 1.517648458480835


  8%|▊         | 10/126 [01:52<21:50, 11.29s/it]

2024-09-11 22:39:56,085 - my_module - INFO - Epoch 8, loss = 1.7903840325095437


 16%|█▌        | 20/126 [03:49<20:37, 11.68s/it]

2024-09-11 22:41:52,871 - my_module - INFO - Epoch 8, loss = 1.8019921552567255


 24%|██▍       | 30/126 [06:01<20:45, 12.97s/it]

2024-09-11 22:44:05,615 - my_module - INFO - Epoch 8, loss = 1.8406938352892477


 32%|███▏      | 40/126 [08:01<17:15, 12.04s/it]

2024-09-11 22:46:04,392 - my_module - INFO - Epoch 8, loss = 1.809131415878854


 40%|███▉      | 50/126 [09:58<14:52, 11.74s/it]

2024-09-11 22:48:01,950 - my_module - INFO - Epoch 8, loss = 1.7974420411914003


 48%|████▊     | 60/126 [12:14<18:53, 17.17s/it]

2024-09-11 22:50:41,506 - my_module - INFO - Epoch 8, loss = 1.7902268933468177


 56%|█████▌    | 70/126 [17:59<32:11, 34.49s/it]

2024-09-11 22:56:23,997 - my_module - INFO - Epoch 8, loss = 1.7951330131208394


 63%|██████▎   | 80/126 [22:19<13:34, 17.72s/it]

2024-09-11 23:00:27,541 - my_module - INFO - Epoch 8, loss = 1.796921357696439


 67%|██████▋   | 84/126 [23:53<11:56, 17.06s/it]


KeyboardInterrupt: 

In [10]:
from pathlib import Path

# Persist only the learned weights (state_dict). If the training cell was
# interrupted, this checkpoint holds the partially trained model.
# The output directory is created first because torch.save does not create it.
output_path = Path('../outputs/model.pth')
output_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), output_path)


In [None]:
from loss.loss import CrossEntropyLoss

# Manual pass over the whole dataset that records the loss of every batch.
# UPDATE_WEIGHTS = False: evaluation only. The model is put in eval mode and no
#   autograd graph is built. Note that the model is left in eval mode afterwards.
# UPDATE_WEIGHTS = True: a hand-written training loop. Each batch is
#   zero_grad -> backward -> step.
UPDATE_WEIGHTS = False

optim = torch.optim.Adam(model.parameters(), lr=1e-5)
criterion = CrossEntropyLoss()
running_loss = []

model.train(UPDATE_WEIGHTS)
with torch.set_grad_enabled(UPDATE_WEIGHTS):
    for batch, (q, a) in enumerate(data_loader):
        q = {k: v.to(device) for k, v in q.items()}
        a = {k: v.to(device) for k, v in a.items()}
        # The third model output is what the loss consumes. It is presumably the
        # question/answer similarity scores (TODO confirm against DualEncoder).
        _, _, sin_score = model(q, a)
        loss = criterion(sin_score)
        if UPDATE_WEIGHTS:
            # Clear the previous batch's gradients; otherwise they accumulate.
            optim.zero_grad()
            loss.backward()
            optim.step()
        running_loss.append(loss.item())
        print(f"batch {batch}: loss = {running_loss[-1]}")

if running_loss:
    print(f"mean loss over {len(running_loss)} batches: {sum(running_loss) / len(running_loss)}")