In [1]:
    #| default_exp run
#|  export
from fastcore.script import call_parse
def split_string(string):
    # Removing the parentheses and splitting the string by comma
    parts = string[1:-1].split(",")
    # Removing the whitespace and quotes from the parts
    parts = [part.strip().strip("'") for part in parts]
    return parts[0], parts[1]

def return_iters(db:str # Path to db
                 ):
    train_iter = []
    test_iter = []
    file = open(db, 'r', encoding='latin1')
    mapping = {
        "Libertarian Left": 1,
        "Libertarian Right": 2,
        "Authoritarian Left": 3,
        "Authoritarian Right": 4,
        "Centrist": 5,
        "Authoritarian Center": 6,
        "Left": 7,
        "Right": 8,
        "Libertarian Center": 9,
    }
    lines = file.readlines()
    for line in lines:
        opinion,text = split_string(line)
        train_iter+=[(mapping[opinion],text)]
        test_iter+=[(mapping[opinion],text)]
    train_iter = iter(train_iter)
    test_iter = iter(test_iter)
    file.close()
    return train_iter, test_iter

In [9]:
#|  export
from torchtext.data.utils import get_tokenizer
# from Political_Compass_AI.data_processing import return_iters
# from Political_Compass_AI.data_processing import split_string
from Political_Compass_AI.data_processing import yield_tokens
from Political_Compass_AI.data_processing import collate_batch
from Political_Compass_AI.model import TextClassificationModel
from Political_Compass_AI.training import train
from Political_Compass_AI.training import evaluate
from torchtext.data.functional import to_map_style_dataset
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
import time
import torch
import optuna
from optuna.trial import TrialState
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
from torchvision import datasets
from torchvision import transforms
import pandas as pd
def define_model(trial,vocab_size, emsize, num_class):
    model = TextClassificationModel(vocab_size, emsize, num_class)
    return model
def collate_batch(
        batch
):
    global text_pipeline
    global db
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    label_pipeline = lambda x: int(x) - 1
    label_list, text_list, offsets = [], [], [0]
    for (_label, _text) in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        offsets.append(processed_text.size(0))
    label_list = torch.tensor(label_list, dtype=torch.int64)
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text_list = torch.cat(text_list)
    return label_list.to(device), text_list.to(device), offsets.to(device)

def objective(
    trial,

):
    global text_pipeline
    global db
    # BATCH_SIZE = trial.suggest_int('n_epochs', 8, 64)
    BATCH_SIZE = trial.suggest_int('n_batch_size',32,128,32)
    db="../194511_DB_Hot_Top_New"
    tokenizer = get_tokenizer('basic_english')
    text_pipeline = lambda x: vocab(tokenizer(x))
    label_pipeline = lambda x: int(x) - 1
    train_iter, test_iter = return_iters(db)
    vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
    vocab.set_default_index(vocab["<unk>"])

    train_iter, test_iter = return_iters(db)
    dataloader = DataLoader(train_iter, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)
    train_iter, test_iter = return_iters(db)
    num_class = len(set([label for (label, text) in train_iter]))
    vocab_size = len(vocab)
    emsize = trial.suggest_int("em_size",64,128,32)
    LR = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = define_model(trial,vocab_size, emsize, num_class).to(device)
    run_ledger = open("Run_Ledger.txt", 'a')
    criterion = torch.nn.CrossEntropyLoss()
    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "RMSprop", "SGD","Adagrad"])
    optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=LR)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
    _optim = optimizer_name
    total_accu = None
    train_iter, test_iter = return_iters(db)
    train_dataset = to_map_style_dataset(train_iter)
    test_dataset = to_map_style_dataset(test_iter)
    num_train = int(len(train_dataset) * 0.95)
    split_train_, split_valid_ = \
        random_split(train_dataset, [num_train, len(train_dataset) - num_train])

    train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,
                                  shuffle=True, collate_fn=collate_batch)
    valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,
                                  shuffle=True, collate_fn=collate_batch)
    test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                                 shuffle=True, collate_fn=collate_batch)
    # EPOCHS = trial.suggest_int('n_epochs', 20, 40)
    EPOCHS = 20
    for epoch in range(1, EPOCHS + 1):
        epoch_start_time = time.time()
        train(train_dataloader, model, optimizer, epoch)
        accu_val = evaluate(valid_dataloader, model)
        if total_accu is not None and total_accu > accu_val:
            scheduler.step()
        else:
            total_accu = accu_val
        print('-' * 59)
        print('| end of epoch {:3d} | time: {:5.2f}s | '
              'valid accuracy {:8.3f} '.format(epoch,
                                               time.time() - epoch_start_time,
                                               accu_val))
        print('-' * 59)
        trial.report(accu_val, epoch)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()
    df_Log = {"Database_file":[],"Epochs":[],"LR":[],"Batch_Size":[],
              "Final_accu":[],"Optimzer":[],"accu_test":[]}
    accu_test = evaluate(test_dataloader,model)
    out = 'test accuracy {:8.3f}'.format(accu_test)
    df_Log["Database_file"].append(db)
    df_Log["Epochs"].append(str(EPOCHS))
    df_Log["LR"].append( str(LR))
    df_Log["Batch_Size"].append(str(BATCH_SIZE))
    df_Log["Final_accu"].append(str(accu_val))
    df_Log["Optimzer"].append(optimizer_name)
    df_Log["accu_test"].append(accu_test)
    dataframe = pd.DataFrame(df_Log)
    dataframe.to_csv('Run_Ledger.csv',mode='a', index=False,sep="\t")
    return accu_val


In [10]:
# run("../uniqueDB.txt")
#torch.save(model.state_dict(), <path_to>)
# model.load_state_dict(torch.load(<path_to>))
from optuna.pruners import MedianPruner

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=100, timeout=None)

pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

print("Study statistics: ")
print("  Number of finished trials: ", len(study.trials))
print("  Number of pruned trials: ", len(pruned_trials))
print("  Number of complete trials: ", len(complete_trials))

print("Best trial:")
trial = study.best_trial

print("  Value: ", trial.value)

print("  Params: ")
for key, value in trial.params.items():
    print("    {}: {}".format(key, value))

[32m[I 2023-03-27 15:34:39,387][0m A new study created in memory with name: no-name-d545e30c-6c7b-4831-bd63-e80bb4100b0c[0m


| epoch   1 |    50/ 1925 batches | accuracy    0.121
| epoch   1 |   100/ 1925 batches | accuracy    0.113
| epoch   1 |   150/ 1925 batches | accuracy    0.114
| epoch   1 |   200/ 1925 batches | accuracy    0.116
| epoch   1 |   250/ 1925 batches | accuracy    0.122
| epoch   1 |   300/ 1925 batches | accuracy    0.115
| epoch   1 |   350/ 1925 batches | accuracy    0.111
| epoch   1 |   400/ 1925 batches | accuracy    0.118
| epoch   1 |   450/ 1925 batches | accuracy    0.123
| epoch   1 |   500/ 1925 batches | accuracy    0.126
| epoch   1 |   550/ 1925 batches | accuracy    0.120
| epoch   1 |   600/ 1925 batches | accuracy    0.111
| epoch   1 |   650/ 1925 batches | accuracy    0.122
| epoch   1 |   700/ 1925 batches | accuracy    0.122
| epoch   1 |   750/ 1925 batches | accuracy    0.122
| epoch   1 |   800/ 1925 batches | accuracy    0.116
| epoch   1 |   850/ 1925 batches | accuracy    0.120
| epoch   1 |   900/ 1925 batches | accuracy    0.121
| epoch   1 |   950/ 1925 ba

[32m[I 2023-03-27 15:38:42,988][0m Trial 0 finished with value: 0.13376516553567758 and parameters: {'n_batch_size': 96, 'em_size': 64, 'lr': 5.954124879893699e-05, 'optimizer': 'Adagrad'}. Best is trial 0 with value: 0.13376516553567758.[0m


| epoch   1 |    50/ 5775 batches | accuracy    0.115
| epoch   1 |   100/ 5775 batches | accuracy    0.117
| epoch   1 |   150/ 5775 batches | accuracy    0.126
| epoch   1 |   200/ 5775 batches | accuracy    0.119
| epoch   1 |   250/ 5775 batches | accuracy    0.122
| epoch   1 |   300/ 5775 batches | accuracy    0.128
| epoch   1 |   350/ 5775 batches | accuracy    0.127
| epoch   1 |   400/ 5775 batches | accuracy    0.116
| epoch   1 |   450/ 5775 batches | accuracy    0.125
| epoch   1 |   500/ 5775 batches | accuracy    0.125
| epoch   1 |   550/ 5775 batches | accuracy    0.115
| epoch   1 |   600/ 5775 batches | accuracy    0.107
| epoch   1 |   650/ 5775 batches | accuracy    0.122
| epoch   1 |   700/ 5775 batches | accuracy    0.119
| epoch   1 |   750/ 5775 batches | accuracy    0.128
| epoch   1 |   800/ 5775 batches | accuracy    0.108
| epoch   1 |   850/ 5775 batches | accuracy    0.125
| epoch   1 |   900/ 5775 batches | accuracy    0.109
| epoch   1 |   950/ 5775 ba

[32m[I 2023-03-27 15:47:51,663][0m Trial 1 finished with value: 0.11464116800329015 and parameters: {'n_batch_size': 32, 'em_size': 128, 'lr': 2.1255384262904024e-05, 'optimizer': 'SGD'}. Best is trial 0 with value: 0.13376516553567758.[0m


| epoch   1 |    50/ 1444 batches | accuracy    0.105
| epoch   1 |   100/ 1444 batches | accuracy    0.102
| epoch   1 |   150/ 1444 batches | accuracy    0.110
| epoch   1 |   200/ 1444 batches | accuracy    0.103
| epoch   1 |   250/ 1444 batches | accuracy    0.102
| epoch   1 |   300/ 1444 batches | accuracy    0.102
| epoch   1 |   350/ 1444 batches | accuracy    0.100
| epoch   1 |   400/ 1444 batches | accuracy    0.102
| epoch   1 |   450/ 1444 batches | accuracy    0.098
| epoch   1 |   500/ 1444 batches | accuracy    0.103
| epoch   1 |   550/ 1444 batches | accuracy    0.102
| epoch   1 |   600/ 1444 batches | accuracy    0.102
| epoch   1 |   650/ 1444 batches | accuracy    0.095
| epoch   1 |   700/ 1444 batches | accuracy    0.102
| epoch   1 |   750/ 1444 batches | accuracy    0.109
| epoch   1 |   800/ 1444 batches | accuracy    0.098
| epoch   1 |   850/ 1444 batches | accuracy    0.104
| epoch   1 |   900/ 1444 batches | accuracy    0.100
| epoch   1 |   950/ 1444 ba

[32m[I 2023-03-27 15:51:38,079][0m Trial 2 finished with value: 0.10189183631503188 and parameters: {'n_batch_size': 128, 'em_size': 128, 'lr': 1.6849196017682853e-05, 'optimizer': 'SGD'}. Best is trial 0 with value: 0.13376516553567758.[0m


| epoch   1 |    50/ 1925 batches | accuracy    0.132
| epoch   1 |   100/ 1925 batches | accuracy    0.128
| epoch   1 |   150/ 1925 batches | accuracy    0.134
| epoch   1 |   200/ 1925 batches | accuracy    0.127
| epoch   1 |   250/ 1925 batches | accuracy    0.128
| epoch   1 |   300/ 1925 batches | accuracy    0.128
| epoch   1 |   350/ 1925 batches | accuracy    0.131
| epoch   1 |   400/ 1925 batches | accuracy    0.132
| epoch   1 |   450/ 1925 batches | accuracy    0.124
| epoch   1 |   500/ 1925 batches | accuracy    0.131
| epoch   1 |   550/ 1925 batches | accuracy    0.134
| epoch   1 |   600/ 1925 batches | accuracy    0.133
| epoch   1 |   650/ 1925 batches | accuracy    0.130
| epoch   1 |   700/ 1925 batches | accuracy    0.132
| epoch   1 |   750/ 1925 batches | accuracy    0.129
| epoch   1 |   800/ 1925 batches | accuracy    0.137
| epoch   1 |   850/ 1925 batches | accuracy    0.120
| epoch   1 |   900/ 1925 batches | accuracy    0.132
| epoch   1 |   950/ 1925 ba

[32m[I 2023-03-27 15:55:45,308][0m Trial 3 finished with value: 0.14342998149290562 and parameters: {'n_batch_size': 96, 'em_size': 64, 'lr': 3.4110334006617385e-05, 'optimizer': 'Adagrad'}. Best is trial 3 with value: 0.14342998149290562.[0m


| epoch   1 |    50/ 1444 batches | accuracy    0.119
| epoch   1 |   100/ 1444 batches | accuracy    0.116
| epoch   1 |   150/ 1444 batches | accuracy    0.121
| epoch   1 |   200/ 1444 batches | accuracy    0.132
| epoch   1 |   250/ 1444 batches | accuracy    0.144
| epoch   1 |   300/ 1444 batches | accuracy    0.142
| epoch   1 |   350/ 1444 batches | accuracy    0.140
| epoch   1 |   400/ 1444 batches | accuracy    0.152
| epoch   1 |   450/ 1444 batches | accuracy    0.144
| epoch   1 |   500/ 1444 batches | accuracy    0.160
| epoch   1 |   550/ 1444 batches | accuracy    0.145
| epoch   1 |   600/ 1444 batches | accuracy    0.155
| epoch   1 |   650/ 1444 batches | accuracy    0.151
| epoch   1 |   700/ 1444 batches | accuracy    0.169
| epoch   1 |   750/ 1444 batches | accuracy    0.157
| epoch   1 |   800/ 1444 batches | accuracy    0.159
| epoch   1 |   850/ 1444 batches | accuracy    0.156
| epoch   1 |   900/ 1444 batches | accuracy    0.163
| epoch   1 |   950/ 1444 ba

[32m[I 2023-03-27 16:00:26,749][0m Trial 4 finished with value: 0.16913427925149085 and parameters: {'n_batch_size': 128, 'em_size': 96, 'lr': 7.584371845419337e-05, 'optimizer': 'Adam'}. Best is trial 4 with value: 0.16913427925149085.[0m


| epoch   1 |    50/ 2888 batches | accuracy    0.145
| epoch   1 |   100/ 2888 batches | accuracy    0.156
| epoch   1 |   150/ 2888 batches | accuracy    0.162
| epoch   1 |   200/ 2888 batches | accuracy    0.174
| epoch   1 |   250/ 2888 batches | accuracy    0.180
| epoch   1 |   300/ 2888 batches | accuracy    0.172
| epoch   1 |   350/ 2888 batches | accuracy    0.158
| epoch   1 |   400/ 2888 batches | accuracy    0.167
| epoch   1 |   450/ 2888 batches | accuracy    0.185
| epoch   1 |   500/ 2888 batches | accuracy    0.159
| epoch   1 |   550/ 2888 batches | accuracy    0.179
| epoch   1 |   600/ 2888 batches | accuracy    0.178
| epoch   1 |   650/ 2888 batches | accuracy    0.169
| epoch   1 |   700/ 2888 batches | accuracy    0.176
| epoch   1 |   750/ 2888 batches | accuracy    0.162
| epoch   1 |   800/ 2888 batches | accuracy    0.173
| epoch   1 |   850/ 2888 batches | accuracy    0.174
| epoch   1 |   900/ 2888 batches | accuracy    0.168
| epoch   1 |   950/ 2888 ba

[32m[I 2023-03-27 16:07:43,419][0m Trial 5 finished with value: 0.1928850503804236 and parameters: {'n_batch_size': 64, 'em_size': 96, 'lr': 0.0008551851682280133, 'optimizer': 'Adam'}. Best is trial 5 with value: 0.1928850503804236.[0m


| epoch   1 |    50/ 2888 batches | accuracy    0.099
| epoch   1 |   100/ 2888 batches | accuracy    0.098
| epoch   1 |   150/ 2888 batches | accuracy    0.107
| epoch   1 |   200/ 2888 batches | accuracy    0.100
| epoch   1 |   250/ 2888 batches | accuracy    0.106
| epoch   1 |   300/ 2888 batches | accuracy    0.113
| epoch   1 |   350/ 2888 batches | accuracy    0.098
| epoch   1 |   400/ 2888 batches | accuracy    0.098
| epoch   1 |   450/ 2888 batches | accuracy    0.096
| epoch   1 |   500/ 2888 batches | accuracy    0.106
| epoch   1 |   550/ 2888 batches | accuracy    0.106
| epoch   1 |   600/ 2888 batches | accuracy    0.106
| epoch   1 |   650/ 2888 batches | accuracy    0.107
| epoch   1 |   700/ 2888 batches | accuracy    0.118
| epoch   1 |   750/ 2888 batches | accuracy    0.114
| epoch   1 |   800/ 2888 batches | accuracy    0.106
| epoch   1 |   850/ 2888 batches | accuracy    0.113
| epoch   1 |   900/ 2888 batches | accuracy    0.109
| epoch   1 |   950/ 2888 ba

[32m[I 2023-03-27 16:08:05,476][0m Trial 6 pruned. [0m


-----------------------------------------------------------
| end of epoch   1 | time: 17.41s | valid accuracy    0.116 
-----------------------------------------------------------
| epoch   1 |    50/ 2888 batches | accuracy    0.109
| epoch   1 |   100/ 2888 batches | accuracy    0.130
| epoch   1 |   150/ 2888 batches | accuracy    0.139
| epoch   1 |   200/ 2888 batches | accuracy    0.125
| epoch   1 |   250/ 2888 batches | accuracy    0.130
| epoch   1 |   300/ 2888 batches | accuracy    0.123
| epoch   1 |   350/ 2888 batches | accuracy    0.146
| epoch   1 |   400/ 2888 batches | accuracy    0.128
| epoch   1 |   450/ 2888 batches | accuracy    0.143
| epoch   1 |   500/ 2888 batches | accuracy    0.135
| epoch   1 |   550/ 2888 batches | accuracy    0.134
| epoch   1 |   600/ 2888 batches | accuracy    0.142
| epoch   1 |   650/ 2888 batches | accuracy    0.139
| epoch   1 |   700/ 2888 batches | accuracy    0.138
| epoch   1 |   750/ 2888 batches | accuracy    0.147
| epoch  

[32m[I 2023-03-27 16:15:16,612][0m Trial 7 finished with value: 0.18075262183837137 and parameters: {'n_batch_size': 64, 'em_size': 128, 'lr': 2.8050026350093766e-05, 'optimizer': 'RMSprop'}. Best is trial 5 with value: 0.1928850503804236.[0m


| epoch   1 |    50/ 5775 batches | accuracy    0.102
| epoch   1 |   100/ 5775 batches | accuracy    0.099
| epoch   1 |   150/ 5775 batches | accuracy    0.107
| epoch   1 |   200/ 5775 batches | accuracy    0.106
| epoch   1 |   250/ 5775 batches | accuracy    0.100
| epoch   1 |   300/ 5775 batches | accuracy    0.097
| epoch   1 |   350/ 5775 batches | accuracy    0.100
| epoch   1 |   400/ 5775 batches | accuracy    0.106
| epoch   1 |   450/ 5775 batches | accuracy    0.102
| epoch   1 |   500/ 5775 batches | accuracy    0.095
| epoch   1 |   550/ 5775 batches | accuracy    0.110
| epoch   1 |   600/ 5775 batches | accuracy    0.098
| epoch   1 |   650/ 5775 batches | accuracy    0.107
| epoch   1 |   700/ 5775 batches | accuracy    0.115
| epoch   1 |   750/ 5775 batches | accuracy    0.112
| epoch   1 |   800/ 5775 batches | accuracy    0.101
| epoch   1 |   850/ 5775 batches | accuracy    0.109
| epoch   1 |   900/ 5775 batches | accuracy    0.102
| epoch   1 |   950/ 5775 ba

[32m[I 2023-03-27 16:15:47,198][0m Trial 8 pruned. [0m


-----------------------------------------------------------
| end of epoch   1 | time: 25.65s | valid accuracy    0.113 
-----------------------------------------------------------
| epoch   1 |    50/ 1444 batches | accuracy    0.113
| epoch   1 |   100/ 1444 batches | accuracy    0.119
| epoch   1 |   150/ 1444 batches | accuracy    0.112
| epoch   1 |   200/ 1444 batches | accuracy    0.128
| epoch   1 |   250/ 1444 batches | accuracy    0.134
| epoch   1 |   300/ 1444 batches | accuracy    0.139
| epoch   1 |   350/ 1444 batches | accuracy    0.143
| epoch   1 |   400/ 1444 batches | accuracy    0.137
| epoch   1 |   450/ 1444 batches | accuracy    0.138
| epoch   1 |   500/ 1444 batches | accuracy    0.150
| epoch   1 |   550/ 1444 batches | accuracy    0.142
| epoch   1 |   600/ 1444 batches | accuracy    0.143
| epoch   1 |   650/ 1444 batches | accuracy    0.155
| epoch   1 |   700/ 1444 batches | accuracy    0.155
| epoch   1 |   750/ 1444 batches | accuracy    0.153
| epoch  

[32m[I 2023-03-27 16:20:14,711][0m Trial 9 finished with value: 0.17324696689286448 and parameters: {'n_batch_size': 128, 'em_size': 128, 'lr': 2.7343534037458656e-05, 'optimizer': 'RMSprop'}. Best is trial 5 with value: 0.1928850503804236.[0m


| epoch   1 |    50/ 2888 batches | accuracy    0.117
| epoch   1 |   100/ 2888 batches | accuracy    0.142
| epoch   1 |   150/ 2888 batches | accuracy    0.155
| epoch   1 |   200/ 2888 batches | accuracy    0.167
| epoch   1 |   250/ 2888 batches | accuracy    0.159
| epoch   1 |   300/ 2888 batches | accuracy    0.178
| epoch   1 |   350/ 2888 batches | accuracy    0.177
| epoch   1 |   400/ 2888 batches | accuracy    0.174
| epoch   1 |   450/ 2888 batches | accuracy    0.165
| epoch   1 |   500/ 2888 batches | accuracy    0.173
| epoch   1 |   550/ 2888 batches | accuracy    0.168
| epoch   1 |   600/ 2888 batches | accuracy    0.182
| epoch   1 |   650/ 2888 batches | accuracy    0.177
| epoch   1 |   700/ 2888 batches | accuracy    0.170
| epoch   1 |   750/ 2888 batches | accuracy    0.172
| epoch   1 |   800/ 2888 batches | accuracy    0.172
| epoch   1 |   850/ 2888 batches | accuracy    0.158
| epoch   1 |   900/ 2888 batches | accuracy    0.163
| epoch   1 |   950/ 2888 ba

[32m[I 2023-03-27 16:27:30,782][0m Trial 10 finished with value: 0.189389265885256 and parameters: {'n_batch_size': 64, 'em_size': 96, 'lr': 0.0009225905356101448, 'optimizer': 'Adam'}. Best is trial 5 with value: 0.1928850503804236.[0m


| epoch   1 |    50/ 2888 batches | accuracy    0.133
| epoch   1 |   100/ 2888 batches | accuracy    0.163
| epoch   1 |   150/ 2888 batches | accuracy    0.162
| epoch   1 |   200/ 2888 batches | accuracy    0.177
| epoch   1 |   250/ 2888 batches | accuracy    0.172
| epoch   1 |   300/ 2888 batches | accuracy    0.160
| epoch   1 |   350/ 2888 batches | accuracy    0.180
| epoch   1 |   400/ 2888 batches | accuracy    0.164
| epoch   1 |   450/ 2888 batches | accuracy    0.166
| epoch   1 |   500/ 2888 batches | accuracy    0.170
| epoch   1 |   550/ 2888 batches | accuracy    0.166
| epoch   1 |   600/ 2888 batches | accuracy    0.171
| epoch   1 |   650/ 2888 batches | accuracy    0.164
| epoch   1 |   700/ 2888 batches | accuracy    0.172
| epoch   1 |   750/ 2888 batches | accuracy    0.165
| epoch   1 |   800/ 2888 batches | accuracy    0.188
| epoch   1 |   850/ 2888 batches | accuracy    0.181
| epoch   1 |   900/ 2888 batches | accuracy    0.162
| epoch   1 |   950/ 2888 ba

[32m[I 2023-03-27 16:34:46,423][0m Trial 11 finished with value: 0.19144560970594282 and parameters: {'n_batch_size': 64, 'em_size': 96, 'lr': 0.0009340811316633697, 'optimizer': 'Adam'}. Best is trial 5 with value: 0.1928850503804236.[0m


| epoch   1 |    50/ 2888 batches | accuracy    0.131
| epoch   1 |   100/ 2888 batches | accuracy    0.168
| epoch   1 |   150/ 2888 batches | accuracy    0.162
| epoch   1 |   200/ 2888 batches | accuracy    0.163
| epoch   1 |   250/ 2888 batches | accuracy    0.162
| epoch   1 |   300/ 2888 batches | accuracy    0.167
| epoch   1 |   350/ 2888 batches | accuracy    0.164
| epoch   1 |   400/ 2888 batches | accuracy    0.167
| epoch   1 |   450/ 2888 batches | accuracy    0.160
| epoch   1 |   500/ 2888 batches | accuracy    0.167
| epoch   1 |   550/ 2888 batches | accuracy    0.167
| epoch   1 |   600/ 2888 batches | accuracy    0.170
| epoch   1 |   650/ 2888 batches | accuracy    0.171
| epoch   1 |   700/ 2888 batches | accuracy    0.177
| epoch   1 |   750/ 2888 batches | accuracy    0.171
| epoch   1 |   800/ 2888 batches | accuracy    0.175
| epoch   1 |   850/ 2888 batches | accuracy    0.176
| epoch   1 |   900/ 2888 batches | accuracy    0.177
| epoch   1 |   950/ 2888 ba

[32m[I 2023-03-27 16:42:02,217][0m Trial 12 finished with value: 0.1855850298169854 and parameters: {'n_batch_size': 64, 'em_size': 96, 'lr': 0.0009366470982505514, 'optimizer': 'Adam'}. Best is trial 5 with value: 0.1928850503804236.[0m


| epoch   1 |    50/ 5775 batches | accuracy    0.139
| epoch   1 |   100/ 5775 batches | accuracy    0.149
| epoch   1 |   150/ 5775 batches | accuracy    0.156
| epoch   1 |   200/ 5775 batches | accuracy    0.154
| epoch   1 |   250/ 5775 batches | accuracy    0.143
| epoch   1 |   300/ 5775 batches | accuracy    0.163
| epoch   1 |   350/ 5775 batches | accuracy    0.146
| epoch   1 |   400/ 5775 batches | accuracy    0.174
| epoch   1 |   450/ 5775 batches | accuracy    0.171
| epoch   1 |   500/ 5775 batches | accuracy    0.168
| epoch   1 |   550/ 5775 batches | accuracy    0.177
| epoch   1 |   600/ 5775 batches | accuracy    0.154
| epoch   1 |   650/ 5775 batches | accuracy    0.171
| epoch   1 |   700/ 5775 batches | accuracy    0.158
| epoch   1 |   750/ 5775 batches | accuracy    0.179
| epoch   1 |   800/ 5775 batches | accuracy    0.171
| epoch   1 |   850/ 5775 batches | accuracy    0.162
| epoch   1 |   900/ 5775 batches | accuracy    0.184
| epoch   1 |   950/ 5775 ba

[32m[I 2023-03-27 16:52:09,138][0m Trial 13 finished with value: 0.18579066419905407 and parameters: {'n_batch_size': 32, 'em_size': 64, 'lr': 0.00045530114208440866, 'optimizer': 'Adam'}. Best is trial 5 with value: 0.1928850503804236.[0m
[33m[W 2023-03-27 16:52:14,755][0m Trial 14 failed with parameters: {'n_batch_size': 96, 'em_size': 96, 'lr': 0.00035321546445956416, 'optimizer': 'Adam'} because of the following error: KeyboardInterrupt().[0m
Traceback (most recent call last):
  File "C:\Users\turet\anaconda3\lib\site-packages\optuna\study\_optimize.py", line 200, in _run_trial
    value_or_values = func(trial)
  File "<ipython-input-9-ac9b3aac0eec>", line 95, in objective
    train(train_dataloader, model, optimizer, epoch)
  File "c:\users\turet\documents\docs\school\semsetera2023\pcm\political_compass_ai\Political_Compass_AI\training.py", line 25, in train
    for idx, (label, text, offsets) in enumerate(dataloader):
  File "C:\Users\turet\anaconda3\lib\site-packages\torch

| epoch   1 |    50/ 1925 batches | accuracy    0.110


KeyboardInterrupt: 

In [15]:
# def predict(text, text_pipeline):
#     with torch.no_grad():
#         text = torch.tensor(text_pipeline(text))
#         output = model(text, torch.tensor([0]))
#         return output.argmax(1).item() + 1
#
# mapping = {
# 1:"Libertarian Left",
# 2:"Libertarian Right",
# 3:"Authoritarian Left",
# 4:"Authoritarian Right",
# }
# model = model.to("cpu")
# # ex_text_str = """
# # """
# # https://old.reddit.com/r/PoliticalCompassMemes/comments/x774os/conservative_you_say_sounds_fine_to_me/inbbz52/
# ex_text_str = """
# deo's mom
# """
# print("This is a %s comment" % mapping[predict(ex_text_str, text_pipeline)])

This is a Libertarian Right comment
