In [None]:
from dataset import FFN2Dataset
from tokenizer import FFN2Tokenizer
from torch.optim import Adam
from datasets import load_dataset
import torch
from models import MeanClassifier, CLSClassifier, LogRegCLSClassifier
from train import train
from torch.optim.lr_scheduler import CyclicLR

# --- Configuration -----------------------------------------------------------
# Pretrained BERT checkpoint identifier to use for each supported language.
bert_map = {
    'bengali': 'google/muril-base-cased',
    'english': 'bert-base-uncased',
    'indonesian': 'cahya/bert-base-indonesian-522M',
    'arabic': 'asafaya/bert-base-arabic',
}
# Language whose rows are selected from the dataset and whose checkpoint is used.
language = "bengali"
# Not referenced by any cell visible here; kept for downstream use.
languages = ["bengali", "indonesian", "arabic"]
bert = bert_map[language]
# Prefer the GPU, but fall back to CPU so the notebook still runs end-to-end
# on machines without CUDA instead of failing inside train().
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# NOTE(review): input_dim / hidden_dim are not used in the visible cells;
# presumably they mirror the classifier head sizes in models.py -- confirm.
input_dim = 768
hidden_dim = 50
# Peak learning rate: used both for Adam and as max_lr of the LR schedule.
lr = 3e-2
batch_size = 32
epochs = 3

# --- Data --------------------------------------------------------------------
dataset = load_dataset("copenlu/answerable_tydiqa")
# Keep only the rows of every split that belong to the selected language.
language_dataset = dataset.filter(lambda row: row['language'] == language)

train_set = language_dataset["train"]
validation_set = language_dataset["validation"]

# Wrap the raw splits in the project dataset class, sharing one tokenizer.
# The names are rebound on purpose: later cells read `train_set` and
# `validation_set` as FFN2Dataset instances.
tokenizer = FFN2Tokenizer(bert)
train_set = FFN2Dataset(train_set, tokenizer)
validation_set = FFN2Dataset(validation_set, tokenizer)

# Mean FFN

In [None]:
# Build the MeanClassifier on top of the selected BERT checkpoint and train it.
mean_model = MeanClassifier(bert)
optimizer = Adam(mean_model.parameters(), lr=lr)

# Schedule: a single step up from 0 to `lr`, then a ramp back down towards 0
# spread over `step_size_down` scheduler steps.
# NOTE(review): len(train_set) * epochs counts samples, not batches. If train()
# steps the scheduler once per batch, the ramp is ~batch_size times longer than
# the run and the LR barely decays -- confirm against train().
scheduler = CyclicLR(
    optimizer,
    base_lr=0.0,
    max_lr=lr,
    step_size_up=1,
    step_size_down=len(train_set) * epochs,
    cycle_momentum=False,
)

best_model = train(
    mean_model,
    optimizer,
    scheduler,
    train_set,
    validation_set,
    epochs=epochs,
    batch_size=batch_size,
    lr=lr,
    device=device,
)
# Persist whatever train() returned as the best model.
torch.save(best_model, 'mean_bert_classifier.pt')

# CLS FFN

In [None]:
# Build the CLSClassifier on top of the selected BERT checkpoint and train it.
cls_model = CLSClassifier(bert)
optimizer = Adam(cls_model.parameters(), lr=lr)

# Schedule: a single step up from 0 to `lr`, then a ramp back down towards 0
# spread over `step_size_down` scheduler steps.
# NOTE(review): len(train_set) * epochs counts samples, not batches. If train()
# steps the scheduler once per batch, the ramp is ~batch_size times longer than
# the run and the LR barely decays -- confirm against train().
scheduler = CyclicLR(
    optimizer,
    base_lr=0.0,
    max_lr=lr,
    step_size_up=1,
    step_size_down=len(train_set) * epochs,
    cycle_momentum=False,
)

best_model = train(
    cls_model,
    optimizer,
    scheduler,
    train_set,
    validation_set,
    epochs=epochs,
    batch_size=batch_size,
    lr=lr,
    device=device,
)
# Persist whatever train() returned as the best model.
torch.save(best_model, 'cls_bert_classifier.pt')

# CLS Logistic Regression

In [None]:
# Build the LogRegCLSClassifier on top of the selected BERT checkpoint and
# train it. Note: this rebinds `cls_model`, replacing the model from the
# previous cell in the kernel namespace.
cls_model = LogRegCLSClassifier(bert)
optimizer = Adam(cls_model.parameters(), lr=lr)

# Schedule: a single step up from 0 to `lr`, then a ramp back down towards 0
# spread over `step_size_down` scheduler steps.
# NOTE(review): len(train_set) * epochs counts samples, not batches. If train()
# steps the scheduler once per batch, the ramp is ~batch_size times longer than
# the run and the LR barely decays -- confirm against train().
scheduler = CyclicLR(
    optimizer,
    base_lr=0.0,
    max_lr=lr,
    step_size_up=1,
    step_size_down=len(train_set) * epochs,
    cycle_momentum=False,
)

best_model = train(
    cls_model,
    optimizer,
    scheduler,
    train_set,
    validation_set,
    epochs=epochs,
    batch_size=batch_size,
    lr=lr,
    device=device,
)
# Persist whatever train() returned as the best model.
torch.save(best_model, 'cls_bert_logreg_classifier.pt')