In [1]:
# Utilities
import time

# Pytorch Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader, random_split

import pytorch_lightning as pl

# Huggingface
from transformers import AutoTokenizer, AutoModel, AdamW

# Repository 
from utils import *
from dataset.triples import TriplesDataset
from model.cross_encoder import CrossEncoder
# from trainer.train import Trainer
%load_ext autoreload
%autoreload 2

from tqdm.notebook import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Experiment hyperparameters, gathered in one mapping.
# `model_name` selects the tokenizer checkpoint, the two *_maxlen entries are
# handed to TriplesDataset, and the whole mapping is unpacked into
# CrossEncoder(**CONFIG).
CONFIG = dict(
    model_name='distilbert-base-uncased',
    query_maxlen=64,
    passage_maxlen=64,
    batch_size=4,
    epochs=4,
    learning_rate=3e-5,
)

In [None]:
# Read the three training inputs from the local data/ directory using the
# project's `util` readers. All three results are passed to TriplesDataset in
# the following cell; their concrete types are not visible from this notebook
# (TODO confirm in the utils package).
# NOTE(review): `util` is only in scope if `from utils import *` exposes it —
# confirm on a fresh kernel.
qidpidtriples = util.read_qidpidtriples('data/qidpidtriples.train.full.2.tsv')
train_queries = util.read_queries('data/queries.train.tsv')
collection = util.read_collection('data/collection.tsv')

In [None]:
# Tokenizer loaded from the checkpoint named in CONFIG.
tokenizer = AutoTokenizer.from_pretrained(CONFIG['model_name'])

# Wrap the loaded collection, queries and triples in a dataset. The tokenizer
# and the two max-length settings are forwarded positionally, in this order.
triples_dataset = TriplesDataset(
    collection,
    train_queries,
    qidpidtriples,
    tokenizer,
    CONFIG['query_maxlen'],
    CONFIG['passage_maxlen'],
)



In [None]:
# Both loaders take their batch size from CONFIG so the setting has a single
# source of truth instead of a literal repeated in this cell.
triples_dataloader = DataLoader(triples_dataset, batch_size=CONFIG['batch_size'])
# NOTE(review): validation iterates the very same dataset as training, so the
# reported validation metrics are not a held-out estimate — swap in a separate
# validation split when one is available.
val_dataloader = DataLoader(triples_dataset, batch_size=CONFIG['batch_size'])

# Sanity check: pull a single batch and print the shape of every entry.
batch = next(iter(triples_dataloader))
for key, val in batch.items():
    print(f'{key}: {val.shape}')

In [None]:
# The epoch budget comes from CONFIG so the trainer and the model (which
# receives the same CONFIG below) agree on how long training runs.
# NOTE(review): the `gpus=` argument appears to be deprecated in newer
# pytorch_lightning releases in favour of `accelerator=`/`devices=` — confirm
# against the installed version before upgrading.
trainer = pl.Trainer(
    fast_dev_run=False,
    gpus=1,
    log_every_n_steps=1,
    max_epochs=CONFIG['epochs'],
)

# Every CONFIG entry is passed to CrossEncoder as a keyword argument.
model = CrossEncoder(**CONFIG)

In [None]:
# Launch training, with the validation loader evaluated alongside it.
trainer.fit(
    model=model,
    train_dataloaders=triples_dataloader,
    val_dataloaders=val_dataloader,
)