In [7]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.autograd import Variable
from torch.nn.functional import cross_entropy
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

In [11]:
# Hyper-parameters and sizes read by the cells below.
params = dict(
    epochs=10,      # number of passes of the training loop
    lr=1e-5,        # learning rate handed to the optimizer
    num_feat=10,    # input width of the re-ranker's first linear layer
    lin_dim=10,     # hidden width of the re-ranker
    reg=0.1,        # weight applied to the parameter-norm penalty
    batch_size=32,  # samples per DataLoader batch
)

In [18]:
class rankqaDS(Dataset):
    """Placeholder dataset: reports a length of zero and returns an all-zero 4-tuple."""

    def __len__(self):
        # No samples are exposed yet.
        return 0

    def __getitem__(self, indx):
        # Every index maps to the same dummy record; `indx` is ignored.
        return (0,) * 4

In [None]:
# Instantiate the dataset and wrap it in a loader that groups
# params["batch_size"] samples per batch.
f_data = rankqaDS()
data_loader = DataLoader(dataset=f_data, batch_size=params["batch_size"])

In [6]:
# Use the GPU when CUDA is available, otherwise stay on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("device - " + str(device))

device - cpu


In [12]:
class ReRanker(nn.Module):
    """Pairwise re-ranker built from a shared two-layer scoring network.

    Both inputs go through the same Linear -> ReLU -> Linear stack, producing
    one scalar score each. The module returns the sigmoid of the score
    difference (head minus body).

    Args:
        num_feat: Size of the last dimension of each input feature tensor.
        lin_dim: Width of the hidden linear layer.
    """

    def __init__(self, num_feat, lin_dim):
        super().__init__()
        # Attribute names are kept as-is so state_dict keys stay stable.
        self.l1 = nn.Linear(num_feat, lin_dim)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(lin_dim, 1)
        self.sig = nn.Sigmoid()

    def _score(self, x):
        """Map features of shape (..., num_feat) to a scalar score of shape (..., 1)."""
        return self.l2(self.relu(self.l1(x)))

    def forward(self, x_head, x_body):
        """Return sigmoid(score(x_head) - score(x_body)).

        Args:
            x_head: Feature tensor whose last dimension is ``num_feat``.
            x_body: Feature tensor with the same shape as ``x_head``.

        Returns:
            Tensor with last dimension 1 and values in [0, 1]. A value above
            0.5 means ``x_head`` scored higher than ``x_body``.
        """
        # The same weights score both candidates, so swapping the inputs
        # yields 1 - output.
        return self.sig(self._score(x_head) - self._score(x_body))

In [14]:
# Build the re-ranker from the configured layer sizes.
model = ReRanker(
    num_feat=params["num_feat"],
    lin_dim=params["lin_dim"],
)
# Module.to() moves the parameters in place and returns the module; keeping it
# as the cell's last expression makes the notebook display the module summary.
model.to(device)

ReRanker(
  (l1): Linear(in_features=10, out_features=10, bias=True)
  (relu): ReLU()
  (l2): Linear(in_features=10, out_features=1, bias=True)
  (sig): Sigmoid()
)

In [15]:
# Adam updates every parameter of `model` with the configured learning rate.
opt = optim.Adam(model.parameters(), lr=params["lr"])
# Mean-squared-error criterion ("mean" is the default reduction, spelled out here).
loss_func = nn.MSELoss(reduction="mean")

In [None]:
# Create the checkpoint directory up front so torch.save below cannot fail on a
# missing path after training time has already been spent.
os.makedirs("./checkpoints", exist_ok=True)

for epoch in range(params["epochs"]):
    total_loss = 0
    for batch in tqdm(data_loader):
        # Each batch is a 4-tuple; the fourth field is not used by this loop.
        x_head, x_body, target, _extra = batch
        x_head = x_head.to(device)
        x_body = x_body.to(device)
        target = target.to(device)

        # Call the module itself (not .forward) so nn.Module hooks run.
        y_pred = model(x_head, x_body)
        # The model's last layer is Linear(lin_dim, 1); index 0 drops that
        # trailing dimension. Assumes target is 1-D (batch,) - TODO confirm.
        loss = loss_func(y_pred[:, 0], target)

        # Penalty: sum of the (unsquared) L2 norms of all parameter tensors.
        reg_l2 = sum(p.norm(2) for p in model.parameters())
        loss = loss + params["reg"] * reg_l2

        # Clear stale gradients, back-propagate, then update the weights.
        opt.zero_grad()
        loss.backward()
        opt.step()

        total_loss += loss.item()

    # Checkpoint on every second epoch and always on the final one. The file
    # name is fixed, so each save overwrites the previous checkpoint.
    if epoch % 2 == 1 or epoch == params["epochs"] - 1:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': opt.state_dict(),
            'loss': total_loss,
        }, "./checkpoints/model.pt")
    print(f"Total Loss - {total_loss}")