In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from src.config import BATCH_SIZE
import random 
import numpy as np 
import torch 
from src.data_loader import read_data
from src.networks.unimodal_BERT import unimodal_dBERT_Dataset

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
g = torch.Generator()
g.manual_seed(0)

def get_dataloader(data_type):
    data = read_data(data_type)
    dataset = unimodal_dBERT_Dataset(data)
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE,shuffle=True, worker_init_fn=seed_worker, generator=g)
    return data_loader

In [3]:
train_loader = get_dataloader('train')
val_loader = get_dataloader('val')
test_loader = get_dataloader('test')

100%|██████████| 687/687 [00:00<00:00, 6327.11it/s]
100%|██████████| 63/63 [00:00<00:00, 6997.36it/s]
100%|██████████| 200/200 [00:00<00:00, 7723.04it/s]


In [4]:
from src.trainer import Trainer
from src.networks.unimodal_BERT import unimodal_dBERT_Model, unimodal_dBERT_Input_transformer
import torch.nn as nn
import torch 

trainer = Trainer(10)

#Set data
trainer.set_data(train_loader, val_loader)

#set model 
model = unimodal_dBERT_Model().cuda()
input_transformer = unimodal_dBERT_Input_transformer()
trainer.set_model(model, input_transformer)

#Backpropagation
loss_fn = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters())
trainer.set_optimizer(optimizer, loss_fn)
trainer.train()

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
100%|██████████| 11/11 [00:12<00:00,  1.11s/it]
100%|██████████| 1/1 [00:00<00:00,  1.05it/s]




For epoch = 0
Training Loss = 0.532492992552844 | Training Accuracy = 0.0761605415860735
Validation Loss = 0.35481739044189453|Validation Accuracy = 0.07936507936507936




100%|██████████| 11/11 [00:10<00:00,  1.02it/s]
100%|██████████| 1/1 [00:00<00:00,  1.05it/s]




For epoch = 1
Training Loss = 0.30091872811317444 | Training Accuracy = 0.11995285299806575
Validation Loss = 0.31721651554107666|Validation Accuracy = 0.031746031746031744




100%|██████████| 11/11 [00:10<00:00,  1.02it/s]
100%|██████████| 1/1 [00:00<00:00,  1.04it/s]




For epoch = 2
Training Loss = 0.29573211615735834 | Training Accuracy = 0.05035058027079304
Validation Loss = 0.37674984335899353|Validation Accuracy = 0.031746031746031744




100%|██████████| 11/11 [00:10<00:00,  1.01it/s]
100%|██████████| 1/1 [00:00<00:00,  1.05it/s]




For epoch = 3
Training Loss = 0.35600813952359284 | Training Accuracy = 0.06159332688588008
Validation Loss = 0.4634316563606262|Validation Accuracy = 0.047619047619047616




100%|██████████| 11/11 [00:10<00:00,  1.01it/s]
100%|██████████| 1/1 [00:00<00:00,  1.04it/s]




For epoch = 4
Training Loss = 0.4199398458003998 | Training Accuracy = 0.05551861702127659
Validation Loss = 0.5172264575958252|Validation Accuracy = 0.12698412698412698




100%|██████████| 11/11 [00:11<00:00,  1.00s/it]
100%|██████████| 1/1 [00:00<00:00,  1.03it/s]




For epoch = 5
Training Loss = 0.4421052743088115 | Training Accuracy = 0.12279376208897484
Validation Loss = 0.5816947817802429|Validation Accuracy = 0.06349206349206349




100%|██████████| 11/11 [00:11<00:00,  1.02s/it]
100%|██████████| 1/1 [00:00<00:00,  1.01it/s]




For epoch = 6
Training Loss = 0.4895936304872686 | Training Accuracy = 0.0773392166344294
Validation Loss = 0.6713918447494507|Validation Accuracy = 0.031746031746031744




100%|██████████| 11/11 [00:11<00:00,  1.03s/it]
100%|██████████| 1/1 [00:00<00:00,  1.00it/s]




For epoch = 7
Training Loss = 0.5345091792670164 | Training Accuracy = 0.03551136363636364
Validation Loss = 0.7299129366874695|Validation Accuracy = 0.0




100%|██████████| 11/11 [00:11<00:00,  1.03s/it]
100%|██████████| 1/1 [00:01<00:00,  1.01s/it]




For epoch = 8
Training Loss = 0.5567641041495583 | Training Accuracy = 0.03511847195357834
Validation Loss = 0.7527030110359192|Validation Accuracy = 0.047619047619047616




100%|██████████| 11/11 [00:11<00:00,  1.04s/it]
100%|██████████| 1/1 [00:00<00:00,  1.00it/s]



For epoch = 9
Training Loss = 0.5747059529477899 | Training Accuracy = 0.029829545454545456
Validation Loss = 0.7320008277893066|Validation Accuracy = 0.047619047619047616







In [None]:
trainer.test(test_loader)