## Loading Dataset

In [None]:
from data import IndexDataset
import pandas as pd
from torch.utils.data import DataLoader

# Configuration — fill these in before running the notebook.
# `dataset_name` is the path passed to pd.read_csv; `column_name` is forwarded to
# IndexDataset (presumably the column it indexes — confirm against data.IndexDataset).
dataset_name = ""
column_name = ""
batch_size = 64

# Load the CSV, wrap it in the project dataset, and batch it.
df = pd.read_csv(filepath_or_buffer=dataset_name)
dataset = IndexDataset(df, column_name)
# NOTE(review): DataLoader's default is shuffle=False, and this single loader is
# reused below for both trainer.fit and trainer.test.
dataloader = DataLoader(dataset=dataset, batch_size=batch_size)

## Training Loop

In [None]:
import lightning as L
import torch
from torch import nn, optim
import torch.nn.functional as F

class LitIndexer(L.LightningModule):
    """Lightning module that trains ``indexer`` to regress indices from mapped inputs.

    Each batch is unpacked as ``(strs, idxs)``. ``strs`` is passed through
    ``mapper``, the result is passed through ``indexer``, and the prediction is
    compared to ``idxs`` with mean-squared error.

    Args:
        mapper: Callable applied to the raw ``strs`` batch. Its output must be
            something ``indexer`` accepts (TODO confirm against the real modules).
        indexer: Callable/module that maps the mapped batch to predicted indices.
        lr: Learning rate for Adam. Defaults to ``1e-3``, the previously
            hard-coded value.
    """

    def __init__(self, mapper, indexer, lr: float = 1e-3):
        super().__init__()
        # If these are nn.Module instances, assigning them as attributes registers
        # their parameters, so self.parameters() in configure_optimizers sees them.
        self.mapper = mapper
        self.indexer = indexer
        self.lr = lr

    def loss(self, pred_idxs, real_idxs):
        """Return the mean-squared error between predicted and target indices.

        The target is cast to the prediction's dtype. An integer target next to
        a float prediction makes ``F.mse_loss`` fail in the backward pass
        ("Found dtype Long but expected Float").
        """
        # NOTE(review): if pred_idxs is (B, 1) while real_idxs is (B,), mse_loss
        # broadcasts to (B, B) and only warns — confirm the shapes match.
        return F.mse_loss(pred_idxs, real_idxs.to(pred_idxs.dtype))

    def _shared_step(self, batch):
        """Run mapper -> indexer on a ``(strs, idxs)`` batch and return the loss."""
        strs, idxs = batch
        # Fixed: the attribute is ``self.mapper``; ``self.mappers`` raised AttributeError.
        mapped_strs = self.mapper(strs)
        pred_idxs = self.indexer(mapped_strs)
        return self.loss(pred_idxs, idxs)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss; Lightning backpropagates the return value."""
        loss = self._shared_step(batch)
        self.log("train_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss for one batch."""
        loss = self._shared_step(batch)
        self.log("test_loss", loss)

    def configure_optimizers(self):
        """Return an Adam optimizer over all registered parameters."""
        optimizer = optim.Adam(self.parameters(), lr=self.lr)
        return optimizer


In [None]:
# TODO: pass the real modules, e.g. LitIndexer(mapper=my_mapper, indexer=my_indexer).
# Both arguments are required, so this call raises TypeError until they are supplied.
model = LitIndexer()

# Default Trainer settings; train on the dataloader built in the loading cell.
trainer = L.Trainer()
trainer.fit(model=model, train_dataloaders=dataloader)

In [None]:
# NOTE(review): this evaluates on the same `dataloader` used for training, so the
# reported test_loss is a training-set loss; use a held-out split for a real test.
trainer.test(model=model, dataloaders=dataloader)