In [None]:
# Reload edited modules (e.g. the project's src/ package) automatically before
# each cell runs, so edits to src/ are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from src.config import BATCH_SIZE
import random 
import numpy as np 
import torch 
from src.data_loader import read_data
from src.networks.unimodal_BERT import unimodal_dBERT_Dataset

def seed_worker(worker_id):
    """Seed NumPy's global RNG and Python's ``random`` from the torch seed.

    Used as ``worker_init_fn`` for the DataLoaders built in this notebook,
    following the PyTorch reproducibility recipe.

    Args:
        worker_id: index of the DataLoader worker; unused, but required by
            the ``worker_init_fn`` calling convention.
    """
    # torch's initial seed can exceed NumPy's accepted range, so reduce it to 32 bits.
    derived_seed = torch.initial_seed() % (1 << 32)
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)
# Single seed for the whole notebook. Seed every RNG library in use so that model
# weight initialisation and any other global-RNG draws are reproducible too, not
# just the DataLoader shuffling driven by `g`.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Dedicated generator handed to the DataLoaders for their shuffling/seed draws.
g = torch.Generator()
g.manual_seed(SEED)

def get_dataloader(data_type, shuffle=None, batch_size=None):
    """Build a DataLoader over one data split.

    Args:
        data_type: split name passed to ``read_data`` (this notebook uses
            'train', 'val' and 'test').
        shuffle: whether to shuffle each epoch. Defaults to True only for the
            'train' split; evaluation splits keep a fixed order so their batch
            composition does not change between runs or epochs.
        batch_size: samples per batch. Defaults to ``BATCH_SIZE`` from
            ``src.config``, looked up at call time so an autoreloaded config
            value is respected.

    Returns:
        A ``torch.utils.data.DataLoader`` wrapping ``unimodal_dBERT_Dataset``.

    NOTE: ``num_workers`` is left at the DataLoader default (0), so
    ``seed_worker`` only takes effect if worker processes are enabled.
    """
    if shuffle is None:
        shuffle = data_type == 'train'
    if batch_size is None:
        batch_size = BATCH_SIZE

    data = read_data(data_type)
    dataset = unimodal_dBERT_Dataset(data)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        worker_init_fn=seed_worker,
        generator=g,
    )

In [None]:
# One DataLoader per split, built in train -> val -> test order.
train_loader, val_loader, test_loader = (
    get_dataloader(split) for split in ('train', 'val', 'test')
)

In [None]:
from src.trainer import Trainer
from src.networks.unimodal_BERT import unimodal_dBERT_Model, unimodal_dBERT_Input_transformer
import torch.nn as nn
import torch 

# Hyperparameters for this run.
LEARNING_RATE = 5e-4
L2_RATE = 1e-3  # forwarded to Trainer.train as `l2_r`

# NOTE(review): the meaning of Trainer's positional argument (12) is not visible
# here — presumably the number of epochs; confirm against src.trainer.Trainer.
trainer = Trainer(12)
trainer.set_data(train_loader, val_loader)

# Model lives on the GPU; the input transformer is registered alongside it.
model = unimodal_dBERT_Model().cuda()
input_transformer = unimodal_dBERT_Input_transformer()
trainer.set_model(model, input_transformer)

# nn.BCELoss expects probabilities in [0, 1]; presumably the model ends in a
# sigmoid — confirm in src.networks.unimodal_BERT.
loss_fn = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
trainer.set_optimizer(optimizer, loss_fn)

trainer.train(l2_r=L2_RATE)
trainer.plot()

In [None]:
# Evaluate the trained model on the held-out test split; what is reported is
# determined by src.trainer.Trainer.test.
trainer.test(test_loader)

In [None]:
from pathlib import Path

# Relative to the notebook's working directory. Create the folder first:
# torch.save does not create missing parent directories.
MODEL_PATH = Path("Models/unimodal_dBERT_Model.model")
MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), MODEL_PATH)