# Pytorch Training UI

### Import Libraries

In [1]:
# detect whether this is a google environment

COLAB_ENVIRONMENT = False

try:
    from google.colab import drive
    drive.mount('/content/drive')
    COLAB_ENVIRONMENT = True
except:
    pass

In [2]:
import sys
import os
import pickle

if COLAB_ENVIRONMENT:
    py_file_location = "./drive/MyDrive/LAB/COMP90051-A1__Groupwork__Py/PrivatePackages/pytorch" # my private packages are stored here
    home_directory = './drive/MyDrive/LAB/COMP90051-A1__Groupwork__Py/' # my home directory is stored in ./LAB of google drive
    !pip install einops
else:
    py_file_location = './PrivatePackages/pytorch'
    home_directory = './'

sys.path.append(os.path.abspath(py_file_location))

from environment import *
from utils import *

from sklearn.model_selection import train_test_split

In [3]:
from model.model_class import LSTM, BERT, LSTM_DANN, BERT_DANN, LSTM_DCE_DANN, BERT_DCE_DANN, LSTM_Hinge, BERT_Hinge

### Set Seed and Load Data

In [4]:
SEED = 2608

In [5]:
data1 = []
with open(home_directory + '/data/raw/comp90051-2024s1-project-1/domain1_train_data.json', 'r') as f:
    for line in f:
        data1.append(json.loads(line))

data2 = []
with open(home_directory + './data/raw/comp90051-2024s1-project-1/domain2_train_data.json', 'r') as f:
    for line in f:
        data2.append(json.loads(line))

data_test = []
with open(home_directory + '/data/raw/comp90051-2024s1-project-1/test_data.json', 'r') as f:
    for line in f:
        data_test.append(json.loads(line))

# create domain labels for data
for i in range(len(data1)):
    data1[i]['domain'] = 0
for i in range(len(data2)):
    data2[i]['domain'] = 1

In [6]:
# Train Val Test Split

# get labels for stratification
label1 = [instance['label'] for instance in data1]
label2 = [instance['label'] for instance in data2]

train_ix_1, val_test_ix_1 = train_test_split(range(len(data1)), test_size=0.3, random_state=SEED, stratify = label1)
train_ix_2, val_test_ix_2 = train_test_split(range(len(data2)), test_size=0.3, random_state=SEED, stratify = label2)
val_ix_1, test_ix_1 = train_test_split(val_test_ix_1, test_size=0.5, random_state=SEED, stratify = [data1[i]['label'] for i in val_test_ix_1])
val_ix_2, test_ix_2 = train_test_split(val_test_ix_2, test_size=0.5, random_state=SEED, stratify = [data2[i]['label'] for i in val_test_ix_2])

# split data according to the index from train_test_split
train_data_1 = [data1[i] for i in train_ix_1]
val_data_1 = [data1[i] for i in val_ix_1]
test_data_1 = [data1[i] for i in test_ix_1]
train_data_2 = [data2[i] for i in train_ix_2]
val_data_2 = [data2[i] for i in val_ix_2]
test_data_2 = [data2[i] for i in test_ix_2]

# combine the data
train_data = train_data_1 + train_data_2
val_data = val_data_1 + val_data_2
test_data = test_data_1 + test_data_2

---
### Models

#### 1. Pure Classification

In [7]:
# BERT - CELoss

MAX_SENTENCE_LENGTH = 256
MIN_FREQUENCY = 40 # because 40 is statistical sample requirement
MAKE_CROPPED_REMAINS_INTO_NEW_INSTANCE = False
LOW_FREQ_TOKEN = False
CLS = True
PAD_FRONT = False
W2V_CONTEXT_WINDOW = 5 # 2 to left, 2 to right
cropped_train_data = crop_sentence_length(train_data, max_sentence_length = MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = MAKE_CROPPED_REMAINS_INTO_NEW_INSTANCE)
cropped_val_data = crop_sentence_length(val_data, max_sentence_length =  MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = False)
cropped_test_data = crop_sentence_length(test_data, max_sentence_length = MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = False)
cropped_future_data = crop_sentence_length(data_test, max_sentence_length = MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = False)
raw_token_pytorch_map = get_raw_token_pytorch_map(data = cropped_train_data, min_frequency = MIN_FREQUENCY)
train_x, train_y, val_x, val_y, test_x, test_y, train_dom, val_dom, test_dom, future_x = Data_Factory(cropped_train_data, \
                                                              cropped_val_data, \
                                                                cropped_test_data, \
                                                                    cropped_future_data, \
                                                                        MAX_SENTENCE_LENGTH, \
                                                                            raw_token_pytorch_map, \
                                                                                CLS=CLS, \
                                                                                    low_freq_special_token=LOW_FREQ_TOKEN, \
                                                                                        pad_front=PAD_FRONT)
pos_prior, neg_prior = get_distribution(train_y)
print('class prior:', pos_prior, neg_prior)
pos_dom_prior, neg_dom_prior = get_distribution(train_dom)
print('domain prior:', pos_dom_prior, neg_dom_prior)
dom1_pos_prior, dom1_neg_prior = get_distribution([[0, 1] if label == 1 else [1, 0] for label in label1])
print('dom1 class prior:', dom1_pos_prior, dom1_neg_prior)
dom2_pos_prior, dom2_neg_prior = get_distribution([[0, 1] if label == 1 else [1, 0] for label in label2])
print('dom2 class prior:', dom2_pos_prior, dom2_neg_prior)
pretrain_x, pretrain_y, pretrain_mask, pretrain_dom, preval_x, preval_y, preval_mask, preval_dom = BERT_pretrain_DataFactory(train_data, val_data, SEED, raw_token_pytorch_map, MAX_SENTENCE_LENGTH)

print('---')

class BERT_config:
    # ----------------- architectual hyperparameters ----------------- #
    d_model = 256
    d_ff = 1024 # = 4* d_model
    n_heads = 8
    dropout = 0.1
    e_layers = 8
    embedding_aggregation = 'cls' # TODO
    n_mlp_layers = 1
    res_learning = False
    activation = nn.ReLU()
    mask_flag = False # causal mask
    train_embedding = True
    # ----------------- optimisation hyperparameters ----------------- #
    random_state = SEED
    batch_size = 8
    epochs = 32
    lr = 1e-5
    patience = 10
    loss = nn.CrossEntropyLoss()
    # loss = nn.CrossEntropyLoss(weight=torch.FloatTensor([pos_prior, neg_prior]))
    validation_loss = nn.CrossEntropyLoss()
    # validation_loss = nn.CrossEntropyLoss(weight=torch.FloatTensor([pos_prior, neg_prior]))
    pretrain_loss = nn.CrossEntropyLoss()
    pretrain_validation_loss = nn.CrossEntropyLoss()
    regularisation_loss = None
    scheduler = False
    grad_clip = False
    # ----------------- operation hyperparameters ----------------- #
    d_output = 2
    seq_len = MAX_SENTENCE_LENGTH
    n_unique_tokens = len(raw_token_pytorch_map)
    # ----------------- saving hyperparameters ----------------- #
    rootpath = home_directory + './'
    saving_address = home_directory + f'./results/'
    name = f'BERT_Classifier'

model = BERT(BERT_config) # initialise the model

# train the model (all cells except this one will print training log and evaluation at each batch)
# pretrain_best_epoch = model.fit_pretrain(pretrain_x, pretrain_y, pretrain_dom, pretrain_mask, preval_x, preval_y, preval_dom, preval_mask)
# print()

# as model automatically saves best epoch, will now load the best epoch and evaluate on test set
# model.load()
# model.eval_pretrain(preval_x, preval_y, preval_dom, preval_mask, pretrain_best_epoch, evaluation_mode = True)

# train the model (all cells except this one will print training log and evaluation at each batch)
best_epoch = model.fit(train_x, train_y, train_dom, val_x, val_y, val_dom)
print()

# as model automatically saves best epoch, will now load the best epoch and evaluate on test set
model.load()
model.eval(val_x, val_y, val_dom, best_epoch, evaluation_mode = True)
model.eval(test_x, test_y, test_dom, best_epoch, evaluation_mode = True)

100%|██████████| 12600/12600 [00:00<00:00, 28934.71it/s]
100%|██████████| 2700/2700 [00:00<00:00, 93906.22it/s]
100%|██████████| 2700/2700 [00:00<00:00, 120085.05it/s]
 11%|█▏        | 453/4000 [00:00<00:00, 64558.45it/s]


KeyboardInterrupt: 

In [None]:
# BERT - WCELoss

MAX_SENTENCE_LENGTH = 256
MIN_FREQUENCY = 40 # because 40 is statistical sample requirement
MAKE_CROPPED_REMAINS_INTO_NEW_INSTANCE = False
LOW_FREQ_TOKEN = False
CLS = True
PAD_FRONT = False
W2V_CONTEXT_WINDOW = 5 # 2 to left, 2 to right
cropped_train_data = crop_sentence_length(train_data, max_sentence_length = MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = MAKE_CROPPED_REMAINS_INTO_NEW_INSTANCE)
cropped_val_data = crop_sentence_length(val_data, max_sentence_length =  MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = False)
cropped_test_data = crop_sentence_length(test_data, max_sentence_length = MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = False)
cropped_future_data = crop_sentence_length(data_test, max_sentence_length = MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = False)
raw_token_pytorch_map = get_raw_token_pytorch_map(data = cropped_train_data, min_frequency = MIN_FREQUENCY)
train_x, train_y, val_x, val_y, test_x, test_y, train_dom, val_dom, test_dom, future_x = Data_Factory(cropped_train_data, \
                                                              cropped_val_data, \
                                                                cropped_test_data, \
                                                                    cropped_future_data, \
                                                                        MAX_SENTENCE_LENGTH, \
                                                                            raw_token_pytorch_map, \
                                                                                CLS=CLS, \
                                                                                    low_freq_special_token=LOW_FREQ_TOKEN, \
                                                                                        pad_front=PAD_FRONT)
pos_prior, neg_prior = get_distribution(train_y)
print('class prior:', pos_prior, neg_prior)
pos_dom_prior, neg_dom_prior = get_distribution(train_dom)
print('domain prior:', pos_dom_prior, neg_dom_prior)
dom1_pos_prior, dom1_neg_prior = get_distribution([[0, 1] if label == 1 else [1, 0] for label in label1])
print('dom1 class prior:', dom1_pos_prior, dom1_neg_prior)
dom2_pos_prior, dom2_neg_prior = get_distribution([[0, 1] if label == 1 else [1, 0] for label in label2])
print('dom2 class prior:', dom2_pos_prior, dom2_neg_prior)
pretrain_x, pretrain_y, pretrain_mask, pretrain_dom, preval_x, preval_y, preval_mask, preval_dom = BERT_pretrain_DataFactory(train_data, val_data, SEED, raw_token_pytorch_map, MAX_SENTENCE_LENGTH)

print('---')

class BERT_config:
    # ----------------- architectual hyperparameters ----------------- #
    d_model = 256
    d_ff = 1024 # = 4* d_model
    n_heads = 8
    dropout = 0.1
    e_layers = 8
    embedding_aggregation = 'cls' # TODO
    n_mlp_layers = 1
    res_learning = False
    activation = nn.ReLU()
    mask_flag = False # causal mask
    train_embedding = True
    # ----------------- optimisation hyperparameters ----------------- #
    random_state = SEED
    batch_size = 8
    epochs = 32
    lr = 1e-5
    patience = 10
    # loss = nn.CrossEntropyLoss()
    loss = nn.CrossEntropyLoss(weight=torch.FloatTensor([pos_prior, neg_prior]))
    # validation_loss = nn.CrossEntropyLoss()
    validation_loss = nn.CrossEntropyLoss(weight=torch.FloatTensor([pos_prior, neg_prior]))
    pretrain_loss = nn.CrossEntropyLoss()
    pretrain_validation_loss = nn.CrossEntropyLoss()
    regularisation_loss = None
    scheduler = False
    grad_clip = False
    # ----------------- operation hyperparameters ----------------- #
    d_output = 2
    seq_len = MAX_SENTENCE_LENGTH
    n_unique_tokens = len(raw_token_pytorch_map)
    # ----------------- saving hyperparameters ----------------- #
    rootpath = home_directory + './'
    saving_address = home_directory + f'./results/'
    name = f'BERT_Classifier'

model = BERT(BERT_config) # initialise the model

# train the model (all cells except this one will print training log and evaluation at each batch)
# pretrain_best_epoch = model.fit_pretrain(pretrain_x, pretrain_y, pretrain_dom, pretrain_mask, preval_x, preval_y, preval_dom, preval_mask)
# print()

# as model automatically saves best epoch, will now load the best epoch and evaluate on test set
# model.load()
# model.eval_pretrain(preval_x, preval_y, preval_dom, preval_mask, pretrain_best_epoch, evaluation_mode = True)

# train the model (all cells except this one will print training log and evaluation at each batch)
best_epoch = model.fit(train_x, train_y, train_dom, val_x, val_y, val_dom)
print()

# as model automatically saves best epoch, will now load the best epoch and evaluate on test set
model.load()
model.eval(val_x, val_y, val_dom, best_epoch, evaluation_mode = True)
model.eval(test_x, test_y, test_dom, best_epoch, evaluation_mode = True)

In [None]:
# model.load()
# model.eval(val_x, val_y, val_dom, best_epoch, evaluation_mode = True)
# model.eval(test_x, test_y, test_dom, best_epoch, evaluation_mode = True)

# EXPERIMENT_NAME = '8bert_1mlpc_1mlpd_256d1024_256d40_8h_0.1_embed_4batch_bwce'

# future_pred_y = model.predict(future_x)

# future_pred_y = [1 if x[1] > x[0] else 0 for x in future_pred_y]

# predictions = pd.DataFrame({'id': range(len(future_pred_y)), 'class': future_pred_y})
# predictions.to_csv(home_directory + f'predictions/{EXPERIMENT_NAME}_classification.csv', index=False)

Epoch 20 Val   | Loss:  0.2023 | Accuracy:  0.8459| F1:  0.7007 | Balanced Accuracy:  0.8337 | Dom Avg Accuracy:  0.7931 |
                Domain 1 Accuracy:  0.7587| Domain 1 F1:  0.7774 | Domain 1 Balanced Accuracy:  0.7587 | 
                Domain 2 Accuracy:  0.8795| Domain 2 F1:  0.5927 | Domain 2 Balanced Accuracy:  0.8275
Epoch 20 Val   | Loss:  0.2083 | Accuracy:  0.8385| F1:  0.6845 | Balanced Accuracy:  0.8206 | Dom Avg Accuracy:  0.7759 |
                Domain 1 Accuracy:  0.7320| Domain 1 F1:  0.7528 | Domain 1 Balanced Accuracy:  0.7320 | 
                Domain 2 Accuracy:  0.8795| Domain 2 F1:  0.5870 | Domain 2 Balanced Accuracy:  0.8198


---
#### 2. DANN

In [None]:
# BERT - DANN CELoss

MAX_SENTENCE_LENGTH = 256
MIN_FREQUENCY = 40 # because 40 is statistical sample requirement
MAKE_CROPPED_REMAINS_INTO_NEW_INSTANCE = False
LOW_FREQ_TOKEN = False
CLS = True
PAD_FRONT = False
W2V_CONTEXT_WINDOW = 5 # 2 to left, 2 to right
cropped_train_data = crop_sentence_length(train_data, max_sentence_length = MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = MAKE_CROPPED_REMAINS_INTO_NEW_INSTANCE)
cropped_val_data = crop_sentence_length(val_data, max_sentence_length =  MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = False)
cropped_test_data = crop_sentence_length(test_data, max_sentence_length = MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = False)
cropped_future_data = crop_sentence_length(data_test, max_sentence_length = MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = False)
raw_token_pytorch_map = get_raw_token_pytorch_map(data = cropped_train_data, min_frequency = MIN_FREQUENCY)
train_x, train_y, val_x, val_y, test_x, test_y, train_dom, val_dom, test_dom, future_x = Data_Factory(cropped_train_data, \
                                                              cropped_val_data, \
                                                                cropped_test_data, \
                                                                    cropped_future_data, \
                                                                        MAX_SENTENCE_LENGTH, \
                                                                            raw_token_pytorch_map, \
                                                                                CLS=CLS, \
                                                                                    low_freq_special_token=LOW_FREQ_TOKEN, \
                                                                                        pad_front=PAD_FRONT)
pos_prior, neg_prior = get_distribution(train_y)
print('class prior:', pos_prior, neg_prior)
pos_dom_prior, neg_dom_prior = get_distribution(train_dom)
print('domain prior:', pos_dom_prior, neg_dom_prior)
dom1_pos_prior, dom1_neg_prior = get_distribution([[0, 1] if label == 1 else [1, 0] for label in label1])
print('dom1 class prior:', dom1_pos_prior, dom1_neg_prior)
dom2_pos_prior, dom2_neg_prior = get_distribution([[0, 1] if label == 1 else [1, 0] for label in label2])
print('dom2 class prior:', dom2_pos_prior, dom2_neg_prior)
pretrain_x, pretrain_y, pretrain_mask, pretrain_dom, preval_x, preval_y, preval_mask, preval_dom = BERT_pretrain_DataFactory(train_data, val_data, SEED, raw_token_pytorch_map, MAX_SENTENCE_LENGTH)

print('---')

class BERT_DANN_config:
    # ----------------- architectual hyperparameters ----------------- #
    d_model = 256
    d_ff = 1024 # = 4* d_model
    n_heads = 8
    dropout = 0.1
    e_layers = 8
    embedding_aggregation = 'cls' # TODO
    n_mlp_clf_layers = 1
    n_mlp_dom_layers = 1
    res_learning = False
    activation = nn.ReLU()
    mask_flag = False # causal mask
    train_embedding = True
    # ----------------- optimisation hyperparameters ----------------- #
    random_state = SEED
    batch_size = 8
    epochs = 32
    lr = 1e-5
    patience = 10
    # loss = nn.CrossEntropyLoss()
    loss = nn.CrossEntropyLoss(weight=torch.FloatTensor([pos_prior, neg_prior]))
    # validation_loss = nn.CrossEntropyLoss()
    validation_loss = nn.CrossEntropyLoss(weight=torch.FloatTensor([pos_prior, neg_prior]))
    domain_loss = nn.CrossEntropyLoss()
    pretrain_loss = nn.CrossEntropyLoss()
    pretrain_validation_loss = nn.CrossEntropyLoss()
    alpha = 0.1
    gradient_reversal_every_n_epoch = 1
    regularisation_loss = None
    scheduler = False
    grad_clip = False
    # ----------------- operation hyperparameters ----------------- #
    d_output = 2
    seq_len = MAX_SENTENCE_LENGTH
    n_unique_tokens = len(raw_token_pytorch_map)
    # ----------------- saving hyperparameters ----------------- #
    rootpath = home_directory + './'
    saving_address = home_directory + f'./results/'
    name = f'BERT_DANN'

model = BERT_DANN(BERT_DANN_config) # initialise the model

# train the model (all cells except this one will print training log and evaluation at each batch)
best_epoch = model.fit(train_x, train_y, train_dom, val_x, val_y, val_dom)
print()

# as model automatically saves best epoch, will now load the best epoch and evaluate on test set
model.load()
model.eval(val_x, val_y, val_dom, best_epoch, evaluation_mode = True)
model.eval(test_x, test_y, test_dom, best_epoch, evaluation_mode = True)

In [None]:
# BERT - DANN WCELoss

MAX_SENTENCE_LENGTH = 256
MIN_FREQUENCY = 40 # because 40 is statistical sample requirement
MAKE_CROPPED_REMAINS_INTO_NEW_INSTANCE = False
LOW_FREQ_TOKEN = False
CLS = True
PAD_FRONT = False
W2V_CONTEXT_WINDOW = 5 # 2 to left, 2 to right
cropped_train_data = crop_sentence_length(train_data, max_sentence_length = MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = MAKE_CROPPED_REMAINS_INTO_NEW_INSTANCE)
cropped_val_data = crop_sentence_length(val_data, max_sentence_length =  MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = False)
cropped_test_data = crop_sentence_length(test_data, max_sentence_length = MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = False)
cropped_future_data = crop_sentence_length(data_test, max_sentence_length = MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = False)
raw_token_pytorch_map = get_raw_token_pytorch_map(data = cropped_train_data, min_frequency = MIN_FREQUENCY)
train_x, train_y, val_x, val_y, test_x, test_y, train_dom, val_dom, test_dom, future_x = Data_Factory(cropped_train_data, \
                                                              cropped_val_data, \
                                                                cropped_test_data, \
                                                                    cropped_future_data, \
                                                                        MAX_SENTENCE_LENGTH, \
                                                                            raw_token_pytorch_map, \
                                                                                CLS=CLS, \
                                                                                    low_freq_special_token=LOW_FREQ_TOKEN, \
                                                                                        pad_front=PAD_FRONT)
pos_prior, neg_prior = get_distribution(train_y)
print('class prior:', pos_prior, neg_prior)
pos_dom_prior, neg_dom_prior = get_distribution(train_dom)
print('domain prior:', pos_dom_prior, neg_dom_prior)
dom1_pos_prior, dom1_neg_prior = get_distribution([[0, 1] if label == 1 else [1, 0] for label in label1])
print('dom1 class prior:', dom1_pos_prior, dom1_neg_prior)
dom2_pos_prior, dom2_neg_prior = get_distribution([[0, 1] if label == 1 else [1, 0] for label in label2])
print('dom2 class prior:', dom2_pos_prior, dom2_neg_prior)
pretrain_x, pretrain_y, pretrain_mask, pretrain_dom, preval_x, preval_y, preval_mask, preval_dom = BERT_pretrain_DataFactory(train_data, val_data, SEED, raw_token_pytorch_map, MAX_SENTENCE_LENGTH)

print('---')

class BERT_DANN_config:
    # ----------------- architectual hyperparameters ----------------- #
    d_model = 256
    d_ff = 1024 # = 4* d_model
    n_heads = 8
    dropout = 0.1
    e_layers = 8
    embedding_aggregation = 'cls' # TODO
    n_mlp_clf_layers = 1
    n_mlp_dom_layers = 1
    res_learning = False
    activation = nn.ReLU()
    mask_flag = False # causal mask
    train_embedding = True
    # ----------------- optimisation hyperparameters ----------------- #
    random_state = SEED
    batch_size = 8
    epochs = 32
    lr = 1e-5
    patience = 10
    # loss = nn.CrossEntropyLoss()
    loss = nn.CrossEntropyLoss(weight=torch.FloatTensor([pos_prior, neg_prior]))
    # validation_loss = nn.CrossEntropyLoss()
    validation_loss = nn.CrossEntropyLoss(weight=torch.FloatTensor([pos_prior, neg_prior]))
    pretrain_loss = nn.CrossEntropyLoss()
    pretrain_validation_loss = nn.CrossEntropyLoss()
    domain_loss = nn.CrossEntropyLoss()
    alpha = 0.1
    gradient_reversal_every_n_epoch = 1
    regularisation_loss = None
    scheduler = False
    grad_clip = False
    # ----------------- operation hyperparameters ----------------- #
    d_output = 2
    seq_len = MAX_SENTENCE_LENGTH
    n_unique_tokens = len(raw_token_pytorch_map)
    # ----------------- saving hyperparameters ----------------- #
    rootpath = home_directory + './'
    saving_address = home_directory + f'./results/'
    name = f'BERT_DANN'

model = BERT_DANN(BERT_DANN_config) # initialise the model

# train the model (all cells except this one will print training log and evaluation at each batch)
best_epoch = model.fit(train_x, train_y, train_dom, val_x, val_y, val_dom)
print()

# as model automatically saves best epoch, will now load the best epoch and evaluate on test set
model.load()
model.eval(val_x, val_y, val_dom, best_epoch, evaluation_mode = True)
model.eval(test_x, test_y, test_dom, best_epoch, evaluation_mode = True)

In [None]:
# EXPERIMENT_NAME = '8bert_1mlpc_1mlpd_256d1024_256d40_8h_0.1_embed_4batch_bwce_low_freq'

# future_pred_y, future_pred_dom = model.predict(future_x)

# future_pred_y = [1 if x[1] > x[0] else 0 for x in future_pred_y]

# predictions = pd.DataFrame({'id': range(len(future_pred_y)), 'class': future_pred_y})
# predictions.to_csv(home_directory + f'predictions/{EXPERIMENT_NAME}_classification.csv', index=False)

100%|██████████| 12600/12600 [00:00<00:00, 15485.68it/s]
100%|██████████| 2700/2700 [00:00<00:00, 25772.33it/s]
100%|██████████| 2700/2700 [00:00<00:00, 3127.14it/s]
100%|██████████| 4000/4000 [00:00<00:00, 41878.65it/s]
100%|██████████| 12600/12600 [00:01<00:00, 12264.92it/s]
100%|██████████| 12600/12600 [00:00<00:00, 37174.52it/s]
100%|██████████| 2700/2700 [00:00<00:00, 40789.31it/s]
100%|██████████| 2700/2700 [00:00<00:00, 35825.73it/s]
100%|██████████| 4000/4000 [00:00<00:00, 39010.34it/s]


class prior: 0.2222222222222222 0.7777777777777778
domain prior: 0.7222222222222222 0.2777777777777778
dom1 class prior: 0.5 0.5
dom2 class prior: 0.11538461538461539 0.8846153846153846


100%|██████████| 12600/12600 [00:02<00:00, 4654.85it/s]
100%|██████████| 2700/2700 [00:00<00:00, 6566.89it/s]


---
Epoch 21 Val   | Classification Loss:  0.1625 | Accuracy:  0.8574| F1:  0.7150 | Balanced Accuracy:  0.8387 | Dom Avg Accuracy:  0.8010 |
                            Domain Loss:  0.5397 | Domain Accuracy:  0.7581 |  
                            Domain 1 Accuracy:  0.7920| Domain 1 F1:  0.8050 | Domain 1 Balanced Accuracy:  0.7920 |  
                            Domain 2 Accuracy:  0.8826| Domain 2 F1:  0.5844 | Domain 2 Balanced Accuracy:  0.8100
Epoch 21 Val   | Classification Loss:  0.1658 | Accuracy:  0.8589| F1:  0.7167 | Balanced Accuracy:  0.8390 | Dom Avg Accuracy:  0.8013 |
                            Domain Loss:  0.5434 | Domain Accuracy:  0.7607 |  
                            Domain 1 Accuracy:  0.7667| Domain 1 F1:  0.7804 | Domain 1 Balanced Accuracy:  0.7667 |  
                            Domain 2 Accuracy:  0.8944| Domain 2 F1:  0.6241 | Domain 2 Balanced Accuracy:  0.8359


---
#### 3. DCE_DANN

In [None]:
# BERT DBCE - celoss expr - just bal domain

MAX_SENTENCE_LENGTH = 256
MIN_FREQUENCY = 40 # because 40 is statistical sample requirement
MAKE_CROPPED_REMAINS_INTO_NEW_INSTANCE = False
LOW_FREQ_TOKEN = False
CLS = True
PAD_FRONT = False
W2V_CONTEXT_WINDOW = 5 # 2 to left, 2 to right
cropped_train_data = crop_sentence_length(train_data, max_sentence_length = MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = MAKE_CROPPED_REMAINS_INTO_NEW_INSTANCE)
cropped_val_data = crop_sentence_length(val_data, max_sentence_length =  MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = False)
cropped_test_data = crop_sentence_length(test_data, max_sentence_length = MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = False)
cropped_future_data = crop_sentence_length(data_test, max_sentence_length = MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = False)
raw_token_pytorch_map = get_raw_token_pytorch_map(data = cropped_train_data, min_frequency = MIN_FREQUENCY)
train_x, train_y, val_x, val_y, test_x, test_y, train_dom, val_dom, test_dom, future_x = Data_Factory(cropped_train_data, \
                                                              cropped_val_data, \
                                                                cropped_test_data, \
                                                                    cropped_future_data, \
                                                                        MAX_SENTENCE_LENGTH, \
                                                                            raw_token_pytorch_map, \
                                                                                CLS=CLS, \
                                                                                    low_freq_special_token=LOW_FREQ_TOKEN, \
                                                                                        pad_front=PAD_FRONT)
pos_prior, neg_prior = get_distribution(train_y)
print('class prior:', pos_prior, neg_prior)
pos_dom_prior, neg_dom_prior = get_distribution(train_dom)
print('domain prior:', pos_dom_prior, neg_dom_prior)
dom1_pos_prior, dom1_neg_prior = get_distribution([[0, 1] if label == 1 else [1, 0] for label in label1])
print('dom1 class prior:', dom1_pos_prior, dom1_neg_prior)
dom2_pos_prior, dom2_neg_prior = get_distribution([[0, 1] if label == 1 else [1, 0] for label in label2])
print('dom2 class prior:', dom2_pos_prior, dom2_neg_prior)
pretrain_x, pretrain_y, pretrain_mask, pretrain_dom, preval_x, preval_y, preval_mask, preval_dom = BERT_pretrain_DataFactory(train_data, val_data, SEED, raw_token_pytorch_map, MAX_SENTENCE_LENGTH)

print('---')

class BERT_DCE_DANN_config:
    # ----------------- architectual hyperparameters ----------------- #
    d_model = 256
    d_ff = 1024 # = 4* d_model
    n_heads = 8
    dropout = 0.1
    e_layers = 8
    embedding_aggregation = 'cls' # TODO
    n_mlp_clf_layers = 1
    n_mlp_dom_layers = 1
    res_learning = False
    activation = nn.ReLU()
    mask_flag = False # causal mask
    train_embedding = True
    # ----------------- optimisation hyperparameters ----------------- #
    random_state = SEED
    batch_size = 8
    epochs = 32
    lr = 1e-5
    patience = 10
    # loss = nn.CrossEntropyLoss()
    loss = nn.CrossEntropyLoss(weight=torch.FloatTensor([pos_prior, neg_prior]))
    # validation_loss = nn.CrossEntropyLoss()
    validation_loss = nn.CrossEntropyLoss(weight=torch.FloatTensor([pos_prior, neg_prior]))
    domain_1_loss = nn.CrossEntropyLoss()
    # domain_1_loss = nn.CrossEntropyLoss(weight=torch.FloatTensor([dom1_pos_prior, dom1_neg_prior]))
    domain_1_validation_loss = nn.CrossEntropyLoss()
    # domain_1_validation_loss = nn.CrossEntropyLoss(weight=torch.FloatTensor([dom1_pos_prior, dom1_neg_prior]))
    domain_2_loss = nn.CrossEntropyLoss()
    # domain_2_loss = nn.CrossEntropyLoss(weight=torch.FloatTensor([dom2_pos_prior, dom2_neg_prior]))
    domain_2_validation_loss = nn.CrossEntropyLoss()
    # domain_2_validation_loss = nn.CrossEntropyLoss(weight=torch.FloatTensor([dom2_pos_prior, dom2_neg_prior]))
    domain_prior = [pos_dom_prior, neg_dom_prior]
    pretrain_loss = nn.CrossEntropyLoss()
    pretrain_validation_loss = nn.CrossEntropyLoss()
    domain_loss = nn.CrossEntropyLoss()
    alpha = 0.1
    gradient_reversal_every_n_epoch = 1
    regularisation_loss = None
    scheduler = False
    grad_clip = False
    # ----------------- operation hyperparameters ----------------- #
    d_output = 2
    seq_len = MAX_SENTENCE_LENGTH
    n_unique_tokens = len(raw_token_pytorch_map)
    # ----------------- saving hyperparameters ----------------- #
    rootpath = home_directory + './'
    saving_address = home_directory + f'./results/'
    name = f'BERT_Classifier'

model = BERT_DCE_DANN(BERT_DCE_DANN_config) # initialise the model

# train the model (all cells except this one will print training log and evaluation at each batch)
best_epoch = model.fit(train_x, train_y, train_dom, val_x, val_y, val_dom)
print()

# as model automatically saves best epoch, will now load the best epoch and evaluate on test set
model.load()
model.eval(val_x, val_y, val_dom, best_epoch, evaluation_mode = True)
model.eval(test_x, test_y, test_dom, best_epoch, evaluation_mode = True)

In [None]:
# BERT DBCE - celoss expr - just bal dom label

MAX_SENTENCE_LENGTH = 256
MIN_FREQUENCY = 40 # because 40 is statistical sample requirement
MAKE_CROPPED_REMAINS_INTO_NEW_INSTANCE = False
LOW_FREQ_TOKEN = False
CLS = True
PAD_FRONT = False
W2V_CONTEXT_WINDOW = 5 # 2 to left, 2 to right
cropped_train_data = crop_sentence_length(train_data, max_sentence_length = MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = MAKE_CROPPED_REMAINS_INTO_NEW_INSTANCE)
cropped_val_data = crop_sentence_length(val_data, max_sentence_length =  MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = False)
cropped_test_data = crop_sentence_length(test_data, max_sentence_length = MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = False)
cropped_future_data = crop_sentence_length(data_test, max_sentence_length = MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = False)
raw_token_pytorch_map = get_raw_token_pytorch_map(data = cropped_train_data, min_frequency = MIN_FREQUENCY)
train_x, train_y, val_x, val_y, test_x, test_y, train_dom, val_dom, test_dom, future_x = Data_Factory(cropped_train_data, \
                                                              cropped_val_data, \
                                                                cropped_test_data, \
                                                                    cropped_future_data, \
                                                                        MAX_SENTENCE_LENGTH, \
                                                                            raw_token_pytorch_map, \
                                                                                CLS=CLS, \
                                                                                    low_freq_special_token=LOW_FREQ_TOKEN, \
                                                                                        pad_front=PAD_FRONT)
pos_prior, neg_prior = get_distribution(train_y)
print('class prior:', pos_prior, neg_prior)
pos_dom_prior, neg_dom_prior = get_distribution(train_dom)
print('domain prior:', pos_dom_prior, neg_dom_prior)
dom1_pos_prior, dom1_neg_prior = get_distribution([[0, 1] if label == 1 else [1, 0] for label in label1])
print('dom1 class prior:', dom1_pos_prior, dom1_neg_prior)
dom2_pos_prior, dom2_neg_prior = get_distribution([[0, 1] if label == 1 else [1, 0] for label in label2])
print('dom2 class prior:', dom2_pos_prior, dom2_neg_prior)
pretrain_x, pretrain_y, pretrain_mask, pretrain_dom, preval_x, preval_y, preval_mask, preval_dom = BERT_pretrain_DataFactory(train_data, val_data, SEED, raw_token_pytorch_map, MAX_SENTENCE_LENGTH)

print('---')

class BERT_DCE_DANN_config:
    # ----------------- architectual hyperparameters ----------------- #
    d_model = 256
    d_ff = 1024 # = 4* d_model
    n_heads = 8
    dropout = 0.1
    e_layers = 8
    embedding_aggregation = 'cls' # TODO
    n_mlp_clf_layers = 1
    n_mlp_dom_layers = 1
    res_learning = False
    activation = nn.ReLU()
    mask_flag = False # causal mask
    train_embedding = True
    # ----------------- optimisation hyperparameters ----------------- #
    random_state = SEED
    batch_size = 8
    epochs = 32
    lr = 1e-5
    patience = 10
    # loss = nn.CrossEntropyLoss()
    loss = nn.CrossEntropyLoss(weight=torch.FloatTensor([pos_prior, neg_prior]))
    # validation_loss = nn.CrossEntropyLoss()
    validation_loss = nn.CrossEntropyLoss(weight=torch.FloatTensor([pos_prior, neg_prior]))
    # domain_1_loss = nn.CrossEntropyLoss()
    domain_1_loss = nn.CrossEntropyLoss(weight=torch.FloatTensor([dom1_pos_prior, dom1_neg_prior]))
    # domain_1_validation_loss = nn.CrossEntropyLoss()
    domain_1_validation_loss = nn.CrossEntropyLoss(weight=torch.FloatTensor([dom1_pos_prior, dom1_neg_prior]))
    # domain_2_loss = nn.CrossEntropyLoss()
    domain_2_loss = nn.CrossEntropyLoss(weight=torch.FloatTensor([dom2_pos_prior, dom2_neg_prior]))
    # domain_2_validation_loss = nn.CrossEntropyLoss()
    domain_2_validation_loss = nn.CrossEntropyLoss(weight=torch.FloatTensor([dom2_pos_prior, dom2_neg_prior]))
    domain_prior = [0.5, 0.5]
    pretrain_loss = nn.CrossEntropyLoss()
    pretrain_validation_loss = nn.CrossEntropyLoss()
    domain_loss = nn.CrossEntropyLoss()
    alpha = 0.1
    gradient_reversal_every_n_epoch = 1
    regularisation_loss = None
    scheduler = False
    grad_clip = False
    # ----------------- operation hyperparameters ----------------- #
    d_output = 2
    seq_len = MAX_SENTENCE_LENGTH
    n_unique_tokens = len(raw_token_pytorch_map)
    # ----------------- saving hyperparameters ----------------- #
    rootpath = home_directory + './'
    saving_address = home_directory + f'./results/'
    name = f'BERT_Classifier'

model = BERT_DCE_DANN(BERT_DCE_DANN_config) # initialise the model

# train the model (all cells except this one will print training log and evaluation at each batch)
best_epoch = model.fit(train_x, train_y, train_dom, val_x, val_y, val_dom)
print()

# as model automatically saves best epoch, will now load the best epoch and evaluate on test set
model.load()
model.eval(val_x, val_y, val_dom, best_epoch, evaluation_mode = True)
model.eval(test_x, test_y, test_dom, best_epoch, evaluation_mode = True)

In [None]:
# BERT DBCE - celoss expr - bal both domain and label

MAX_SENTENCE_LENGTH = 256
MIN_FREQUENCY = 40 # because 40 is statistical sample requirement
MAKE_CROPPED_REMAINS_INTO_NEW_INSTANCE = False
LOW_FREQ_TOKEN = False
CLS = True
PAD_FRONT = False
W2V_CONTEXT_WINDOW = 5 # 2 to left, 2 to right
cropped_train_data = crop_sentence_length(train_data, max_sentence_length = MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = MAKE_CROPPED_REMAINS_INTO_NEW_INSTANCE)
cropped_val_data = crop_sentence_length(val_data, max_sentence_length =  MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = False)
cropped_test_data = crop_sentence_length(test_data, max_sentence_length = MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = False)
cropped_future_data = crop_sentence_length(data_test, max_sentence_length = MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = False)
raw_token_pytorch_map = get_raw_token_pytorch_map(data = cropped_train_data, min_frequency = MIN_FREQUENCY)
train_x, train_y, val_x, val_y, test_x, test_y, train_dom, val_dom, test_dom, future_x = Data_Factory(cropped_train_data, \
                                                              cropped_val_data, \
                                                                cropped_test_data, \
                                                                    cropped_future_data, \
                                                                        MAX_SENTENCE_LENGTH, \
                                                                            raw_token_pytorch_map, \
                                                                                CLS=CLS, \
                                                                                    low_freq_special_token=LOW_FREQ_TOKEN, \
                                                                                        pad_front=PAD_FRONT)
pos_prior, neg_prior = get_distribution(train_y)
print('class prior:', pos_prior, neg_prior)
pos_dom_prior, neg_dom_prior = get_distribution(train_dom)
print('domain prior:', pos_dom_prior, neg_dom_prior)
dom1_pos_prior, dom1_neg_prior = get_distribution([[0, 1] if label == 1 else [1, 0] for label in label1])
print('dom1 class prior:', dom1_pos_prior, dom1_neg_prior)
dom2_pos_prior, dom2_neg_prior = get_distribution([[0, 1] if label == 1 else [1, 0] for label in label2])
print('dom2 class prior:', dom2_pos_prior, dom2_neg_prior)
pretrain_x, pretrain_y, pretrain_mask, pretrain_dom, preval_x, preval_y, preval_mask, preval_dom = BERT_pretrain_DataFactory(train_data, val_data, SEED, raw_token_pytorch_map, MAX_SENTENCE_LENGTH)

print('---')

class BERT_DCE_DANN_config:
    # ----------------- architectual hyperparameters ----------------- #
    d_model = 256
    d_ff = 1024 # = 4* d_model
    n_heads = 8
    dropout = 0.1
    e_layers = 8
    embedding_aggregation = 'cls' # TODO
    n_mlp_clf_layers = 1
    n_mlp_dom_layers = 1
    res_learning = False
    activation = nn.ReLU()
    mask_flag = False # causal mask
    train_embedding = True
    # ----------------- optimisation hyperparameters ----------------- #
    random_state = SEED
    batch_size = 8
    epochs = 32
    lr = 1e-5
    patience = 10
    # loss = nn.CrossEntropyLoss()
    loss = nn.CrossEntropyLoss(weight=torch.FloatTensor([pos_prior, neg_prior]))
    # validation_loss = nn.CrossEntropyLoss()
    validation_loss = nn.CrossEntropyLoss(weight=torch.FloatTensor([pos_prior, neg_prior]))
    # domain_1_loss = nn.CrossEntropyLoss()
    domain_1_loss = nn.CrossEntropyLoss(weight=torch.FloatTensor([dom1_pos_prior, dom1_neg_prior]))
    # domain_1_validation_loss = nn.CrossEntropyLoss()
    domain_1_validation_loss = nn.CrossEntropyLoss(weight=torch.FloatTensor([dom1_pos_prior, dom1_neg_prior]))
    # domain_2_loss = nn.CrossEntropyLoss()
    domain_2_loss = nn.CrossEntropyLoss(weight=torch.FloatTensor([dom2_pos_prior, dom2_neg_prior]))
    # domain_2_validation_loss = nn.CrossEntropyLoss()
    domain_2_validation_loss = nn.CrossEntropyLoss(weight=torch.FloatTensor([dom2_pos_prior, dom2_neg_prior]))
    domain_prior = [pos_dom_prior, neg_dom_prior]
    pretrain_loss = nn.CrossEntropyLoss()
    pretrain_validation_loss = nn.CrossEntropyLoss()
    domain_loss = nn.CrossEntropyLoss()
    alpha = 0.1
    gradient_reversal_every_n_epoch = 1
    regularisation_loss = None
    scheduler = False
    grad_clip = False
    # ----------------- operation hyperparameters ----------------- #
    d_output = 2
    seq_len = MAX_SENTENCE_LENGTH
    n_unique_tokens = len(raw_token_pytorch_map)
    # ----------------- saving hyperparameters ----------------- #
    rootpath = home_directory + './'
    saving_address = home_directory + f'./results/'
    name = f'BERT_Classifier'

model = BERT_DCE_DANN(BERT_DCE_DANN_config) # initialise the model

# train the model (all cells except this one will print training log and evaluation at each batch)
best_epoch = model.fit(train_x, train_y, train_dom, val_x, val_y, val_dom)
print()

# as model automatically saves best epoch, will now load the best epoch and evaluate on test set
model.load()
model.eval(val_x, val_y, val_dom, best_epoch, evaluation_mode = True)
model.eval(test_x, test_y, test_dom, best_epoch, evaluation_mode = True)

100%|██████████| 12600/12600 [00:00<00:00, 62878.33it/s]
100%|██████████| 2700/2700 [00:00<00:00, 5744.28it/s]
100%|██████████| 2700/2700 [00:00<00:00, 135326.00it/s]
100%|██████████| 4000/4000 [00:00<00:00, 119904.92it/s]
100%|██████████| 12600/12600 [00:00<00:00, 42686.88it/s]
100%|██████████| 12600/12600 [00:00<00:00, 53734.31it/s]
100%|██████████| 2700/2700 [00:00<00:00, 52628.84it/s]
100%|██████████| 2700/2700 [00:00<00:00, 54578.59it/s]
100%|██████████| 4000/4000 [00:00<00:00, 53698.43it/s]


class prior: 0.2222222222222222 0.7777777777777778
domain prior: 0.7222222222222222 0.2777777777777778
dom1 class prior: 0.5 0.5
dom2 class prior: 0.11538461538461539 0.8846153846153846


100%|██████████| 12600/12600 [00:03<00:00, 4084.71it/s]
100%|██████████| 2700/2700 [00:00<00:00, 6401.98it/s]


---


100%|█████████▉| 1575/1576 [01:04<00:00, 24.47it/s]


Epoch 1 Train | Loss:  0.0962 | Accuracy:  0.5616| F1:  0.3621 | Balanced Accuracy:  0.5610 | Dom Avg Accuracy:  0.5796 |
                    Domain 1 Accuracy:  0.5297| Domain 1 F1:  0.5024 | Domain 1 Balanced Accuracy:  0.5297 | 
                    Domain 2 Accuracy:  0.5738| Domain 2 F1:  0.2754 | Domain 2 Balanced Accuracy:  0.6295
Epoch 1 Val   | Loss:  0.2324 | Accuracy:  0.6689| F1:  0.3568 | Balanced Accuracy:  0.5776 | Dom Avg Accuracy:  0.6131 |
                Domain 1 Accuracy:  0.5547| Domain 1 F1:  0.3949 | Domain 1 Balanced Accuracy:  0.5547 | 
                Domain 2 Accuracy:  0.7128| Domain 2 F1:  0.3317 | Domain 2 Balanced Accuracy:  0.6715


100%|█████████▉| 1575/1576 [01:04<00:00, 24.56it/s]


Epoch 2 Train | Loss:  0.0917 | Accuracy:  0.6575| F1:  0.4320 | Balanced Accuracy:  0.6320 | Dom Avg Accuracy:  0.6552 |
                    Domain 1 Accuracy:  0.5820| Domain 1 F1:  0.5282 | Domain 1 Balanced Accuracy:  0.5820 | 
                    Domain 2 Accuracy:  0.6865| Domain 2 F1:  0.3656 | Domain 2 Balanced Accuracy:  0.7284
Epoch 2 Val   | Loss:  0.2214 | Accuracy:  0.6959| F1:  0.4263 | Balanced Accuracy:  0.6289 | Dom Avg Accuracy:  0.6561 |
                Domain 1 Accuracy:  0.5907| Domain 1 F1:  0.4892 | Domain 1 Balanced Accuracy:  0.5907 | 
                Domain 2 Accuracy:  0.7364| Domain 2 F1:  0.3807 | Domain 2 Balanced Accuracy:  0.7215


100%|█████████▉| 1575/1576 [01:04<00:00, 24.30it/s]


Epoch 3 Train | Loss:  0.0884 | Accuracy:  0.6679| F1:  0.4532 | Balanced Accuracy:  0.6505 | Dom Avg Accuracy:  0.6795 |
                    Domain 1 Accuracy:  0.6157| Domain 1 F1:  0.5663 | Domain 1 Balanced Accuracy:  0.6157 | 
                    Domain 2 Accuracy:  0.6879| Domain 2 F1:  0.3761 | Domain 2 Balanced Accuracy:  0.7433
Epoch 3 Val   | Loss:  0.2194 | Accuracy:  0.6996| F1:  0.4380 | Balanced Accuracy:  0.6379 | Dom Avg Accuracy:  0.6656 |
                Domain 1 Accuracy:  0.6013| Domain 1 F1:  0.5074 | Domain 1 Balanced Accuracy:  0.6013 | 
                Domain 2 Accuracy:  0.7374| Domain 2 F1:  0.3876 | Domain 2 Balanced Accuracy:  0.7299


100%|█████████▉| 1575/1576 [01:04<00:00, 24.56it/s]


Epoch 4 Train | Loss:  0.0849 | Accuracy:  0.6648| F1:  0.4687 | Balanced Accuracy:  0.6650 | Dom Avg Accuracy:  0.6941 |
                    Domain 1 Accuracy:  0.6414| Domain 1 F1:  0.6094 | Domain 1 Balanced Accuracy:  0.6414 | 
                    Domain 2 Accuracy:  0.6737| Domain 2 F1:  0.3732 | Domain 2 Balanced Accuracy:  0.7469
Epoch 4 Val   | Loss:  0.2117 | Accuracy:  0.5904| F1:  0.4459 | Balanced Accuracy:  0.6444 | Dom Avg Accuracy:  0.6630 |
                Domain 1 Accuracy:  0.6200| Domain 1 F1:  0.6360 | Domain 1 Balanced Accuracy:  0.6200 | 
                Domain 2 Accuracy:  0.5790| Domain 2 F1:  0.3232 | Domain 2 Balanced Accuracy:  0.7060


100%|█████████▉| 1575/1576 [01:04<00:00, 24.50it/s]


Epoch 5 Train | Loss:  0.0815 | Accuracy:  0.6787| F1:  0.4934 | Balanced Accuracy:  0.6877 | Dom Avg Accuracy:  0.7144 |
                    Domain 1 Accuracy:  0.6771| Domain 1 F1:  0.6572 | Domain 1 Balanced Accuracy:  0.6771 | 
                    Domain 2 Accuracy:  0.6793| Domain 2 F1:  0.3784 | Domain 2 Balanced Accuracy:  0.7517
Epoch 5 Val   | Loss:  0.2187 | Accuracy:  0.5293| F1:  0.4621 | Balanced Accuracy:  0.6652 | Dom Avg Accuracy:  0.6724 |
                Domain 1 Accuracy:  0.6533| Domain 1 F1:  0.7168 | Domain 1 Balanced Accuracy:  0.6533 | 
                Domain 2 Accuracy:  0.4815| Domain 2 F1:  0.3003 | Domain 2 Balanced Accuracy:  0.6915


100%|█████████▉| 1575/1576 [01:04<00:00, 24.53it/s]


Epoch 6 Train | Loss:  0.0790 | Accuracy:  0.6821| F1:  0.5123 | Balanced Accuracy:  0.7068 | Dom Avg Accuracy:  0.7266 |
                    Domain 1 Accuracy:  0.6960| Domain 1 F1:  0.6927 | Domain 1 Balanced Accuracy:  0.6960 | 
                    Domain 2 Accuracy:  0.6767| Domain 2 F1:  0.3809 | Domain 2 Balanced Accuracy:  0.7572
Epoch 6 Val   | Loss:  0.2062 | Accuracy:  0.7248| F1:  0.4976 | Balanced Accuracy:  0.6850 | Dom Avg Accuracy:  0.6989 |
                Domain 1 Accuracy:  0.6613| Domain 1 F1:  0.6186 | Domain 1 Balanced Accuracy:  0.6613 | 
                Domain 2 Accuracy:  0.7492| Domain 2 F1:  0.3985 | Domain 2 Balanced Accuracy:  0.7365


100%|█████████▉| 1575/1576 [01:04<00:00, 24.52it/s]


Epoch 7 Train | Loss:  0.0760 | Accuracy:  0.7031| F1:  0.5374 | Balanced Accuracy:  0.7292 | Dom Avg Accuracy:  0.7455 |
                    Domain 1 Accuracy:  0.7194| Domain 1 F1:  0.7197 | Domain 1 Balanced Accuracy:  0.7194 | 
                    Domain 2 Accuracy:  0.6968| Domain 2 F1:  0.3980 | Domain 2 Balanced Accuracy:  0.7715
Epoch 7 Val   | Loss:  0.1905 | Accuracy:  0.6778| F1:  0.5256 | Balanced Accuracy:  0.7226 | Dom Avg Accuracy:  0.7377 |
                Domain 1 Accuracy:  0.7013| Domain 1 F1:  0.7121 | Domain 1 Balanced Accuracy:  0.7013 | 
                Domain 2 Accuracy:  0.6687| Domain 2 F1:  0.3883 | Domain 2 Balanced Accuracy:  0.7741


100%|█████████▉| 1575/1576 [01:04<00:00, 24.51it/s]


Epoch 8 Train | Loss:  0.0732 | Accuracy:  0.7150| F1:  0.5511 | Balanced Accuracy:  0.7408 | Dom Avg Accuracy:  0.7538 |
                    Domain 1 Accuracy:  0.7351| Domain 1 F1:  0.7378 | Domain 1 Balanced Accuracy:  0.7351 | 
                    Domain 2 Accuracy:  0.7073| Domain 2 F1:  0.4032 | Domain 2 Balanced Accuracy:  0.7724
Epoch 8 Val   | Loss:  0.1841 | Accuracy:  0.6885| F1:  0.5336 | Balanced Accuracy:  0.7289 | Dom Avg Accuracy:  0.7456 |
                Domain 1 Accuracy:  0.7107| Domain 1 F1:  0.7178 | Domain 1 Balanced Accuracy:  0.7107 | 
                Domain 2 Accuracy:  0.6800| Domain 2 F1:  0.3965 | Domain 2 Balanced Accuracy:  0.7805


100%|█████████▉| 1575/1576 [01:04<00:00, 24.55it/s]


Epoch 9 Train | Loss:  0.0706 | Accuracy:  0.7234| F1:  0.5594 | Balanced Accuracy:  0.7472 | Dom Avg Accuracy:  0.7620 |
                    Domain 1 Accuracy:  0.7440| Domain 1 F1:  0.7444 | Domain 1 Balanced Accuracy:  0.7440 | 
                    Domain 2 Accuracy:  0.7155| Domain 2 F1:  0.4120 | Domain 2 Balanced Accuracy:  0.7800
Epoch 9 Val   | Loss:  0.1780 | Accuracy:  0.6852| F1:  0.5498 | Balanced Accuracy:  0.7494 | Dom Avg Accuracy:  0.7499 |
                Domain 1 Accuracy:  0.7200| Domain 1 F1:  0.7482 | Domain 1 Balanced Accuracy:  0.7200 | 
                Domain 2 Accuracy:  0.6718| Domain 2 F1:  0.3928 | Domain 2 Balanced Accuracy:  0.7797


100%|█████████▉| 1575/1576 [01:04<00:00, 24.33it/s]


Epoch 10 Train | Loss:  0.0682 | Accuracy:  0.7470| F1:  0.5832 | Balanced Accuracy:  0.7646 | Dom Avg Accuracy:  0.7731 |
                    Domain 1 Accuracy:  0.7577| Domain 1 F1:  0.7596 | Domain 1 Balanced Accuracy:  0.7577 | 
                    Domain 2 Accuracy:  0.7429| Domain 2 F1:  0.4320 | Domain 2 Balanced Accuracy:  0.7884
Epoch 10 Val   | Loss:  0.1841 | Accuracy:  0.6648| F1:  0.5366 | Balanced Accuracy:  0.7393 | Dom Avg Accuracy:  0.7493 |
                Domain 1 Accuracy:  0.7267| Domain 1 F1:  0.7527 | Domain 1 Balanced Accuracy:  0.7267 | 
                Domain 2 Accuracy:  0.6410| Domain 2 F1:  0.3772 | Domain 2 Balanced Accuracy:  0.7720


100%|█████████▉| 1575/1576 [01:04<00:00, 24.54it/s]


Epoch 11 Train | Loss:  0.0653 | Accuracy:  0.7506| F1:  0.5901 | Balanced Accuracy:  0.7711 | Dom Avg Accuracy:  0.7822 |
                    Domain 1 Accuracy:  0.7680| Domain 1 F1:  0.7693 | Domain 1 Balanced Accuracy:  0.7680 | 
                    Domain 2 Accuracy:  0.7440| Domain 2 F1:  0.4380 | Domain 2 Balanced Accuracy:  0.7965
Epoch 11 Val   | Loss:  0.1693 | Accuracy:  0.7344| F1:  0.5795 | Balanced Accuracy:  0.7662 | Dom Avg Accuracy:  0.7690 |
                Domain 1 Accuracy:  0.7493| Domain 1 F1:  0.7608 | Domain 1 Balanced Accuracy:  0.7493 | 
                Domain 2 Accuracy:  0.7287| Domain 2 F1:  0.4244 | Domain 2 Balanced Accuracy:  0.7887


100%|█████████▉| 1575/1576 [01:04<00:00, 24.54it/s]


Epoch 12 Train | Loss:  0.0630 | Accuracy:  0.7627| F1:  0.6044 | Balanced Accuracy:  0.7816 | Dom Avg Accuracy:  0.7918 |
                    Domain 1 Accuracy:  0.7794| Domain 1 F1:  0.7807 | Domain 1 Balanced Accuracy:  0.7794 | 
                    Domain 2 Accuracy:  0.7563| Domain 2 F1:  0.4507 | Domain 2 Balanced Accuracy:  0.8043
Epoch 12 Val   | Loss:  0.1726 | Accuracy:  0.7289| F1:  0.5679 | Balanced Accuracy:  0.7549 | Dom Avg Accuracy:  0.7672 |
                Domain 1 Accuracy:  0.7453| Domain 1 F1:  0.7484 | Domain 1 Balanced Accuracy:  0.7453 | 
                Domain 2 Accuracy:  0.7226| Domain 2 F1:  0.4214 | Domain 2 Balanced Accuracy:  0.7891


100%|█████████▉| 1575/1576 [01:04<00:00, 24.56it/s]


Epoch 13 Train | Loss:  0.0615 | Accuracy:  0.7693| F1:  0.6152 | Balanced Accuracy:  0.7910 | Dom Avg Accuracy:  0.8001 |
                    Domain 1 Accuracy:  0.7966| Domain 1 F1:  0.7994 | Domain 1 Balanced Accuracy:  0.7966 | 
                    Domain 2 Accuracy:  0.7588| Domain 2 F1:  0.4519 | Domain 2 Balanced Accuracy:  0.8036
Epoch 13 Val   | Loss:  0.1720 | Accuracy:  0.7222| F1:  0.5758 | Balanced Accuracy:  0.7673 | Dom Avg Accuracy:  0.7755 |
                Domain 1 Accuracy:  0.7560| Domain 1 F1:  0.7692 | Domain 1 Balanced Accuracy:  0.7560 | 
                Domain 2 Accuracy:  0.7092| Domain 2 F1:  0.4185 | Domain 2 Balanced Accuracy:  0.7951


100%|█████████▉| 1575/1576 [01:04<00:00, 24.51it/s]


Epoch 14 Train | Loss:  0.0595 | Accuracy:  0.7738| F1:  0.6208 | Balanced Accuracy:  0.7950 | Dom Avg Accuracy:  0.8066 |
                    Domain 1 Accuracy:  0.8003| Domain 1 F1:  0.8016 | Domain 1 Balanced Accuracy:  0.8003 | 
                    Domain 2 Accuracy:  0.7636| Domain 2 F1:  0.4613 | Domain 2 Balanced Accuracy:  0.8130
Epoch 14 Val   | Loss:  0.1801 | Accuracy:  0.6863| F1:  0.5618 | Balanced Accuracy:  0.7644 | Dom Avg Accuracy:  0.7621 |
                Domain 1 Accuracy:  0.7453| Domain 1 F1:  0.7776 | Domain 1 Balanced Accuracy:  0.7453 | 
                Domain 2 Accuracy:  0.6636| Domain 2 F1:  0.3892 | Domain 2 Balanced Accuracy:  0.7789


100%|█████████▉| 1575/1576 [01:04<00:00, 24.54it/s]


Epoch 15 Train | Loss:  0.0579 | Accuracy:  0.7797| F1:  0.6302 | Balanced Accuracy:  0.8029 | Dom Avg Accuracy:  0.8116 |
                    Domain 1 Accuracy:  0.8134| Domain 1 F1:  0.8168 | Domain 1 Balanced Accuracy:  0.8134 | 
                    Domain 2 Accuracy:  0.7667| Domain 2 F1:  0.4613 | Domain 2 Balanced Accuracy:  0.8098
Epoch 15 Val   | Loss:  0.1759 | Accuracy:  0.6996| F1:  0.5656 | Balanced Accuracy:  0.7640 | Dom Avg Accuracy:  0.7779 |
                Domain 1 Accuracy:  0.7600| Domain 1 F1:  0.7772 | Domain 1 Balanced Accuracy:  0.7600 | 
                Domain 2 Accuracy:  0.6764| Domain 2 F1:  0.4042 | Domain 2 Balanced Accuracy:  0.7958


100%|█████████▉| 1575/1576 [01:04<00:00, 24.52it/s]


Epoch 16 Train | Loss:  0.0557 | Accuracy:  0.7902| F1:  0.6435 | Balanced Accuracy:  0.8122 | Dom Avg Accuracy:  0.8225 |
                    Domain 1 Accuracy:  0.8254| Domain 1 F1:  0.8275 | Domain 1 Balanced Accuracy:  0.8254 | 
                    Domain 2 Accuracy:  0.7767| Domain 2 F1:  0.4749 | Domain 2 Balanced Accuracy:  0.8195
Epoch 16 Val   | Loss:  0.1698 | Accuracy:  0.6911| F1:  0.5688 | Balanced Accuracy:  0.7717 | Dom Avg Accuracy:  0.7619 |
                Domain 1 Accuracy:  0.7373| Domain 1 F1:  0.7754 | Domain 1 Balanced Accuracy:  0.7373 | 
                Domain 2 Accuracy:  0.6733| Domain 2 F1:  0.3974 | Domain 2 Balanced Accuracy:  0.7864


100%|█████████▉| 1575/1576 [01:04<00:00, 24.30it/s]


Epoch 17 Train | Loss:  0.0539 | Accuracy:  0.7983| F1:  0.6560 | Balanced Accuracy:  0.8223 | Dom Avg Accuracy:  0.8315 |
                    Domain 1 Accuracy:  0.8334| Domain 1 F1:  0.8365 | Domain 1 Balanced Accuracy:  0.8334 | 
                    Domain 2 Accuracy:  0.7848| Domain 2 F1:  0.4877 | Domain 2 Balanced Accuracy:  0.8295
Epoch 17 Val   | Loss:  0.1690 | Accuracy:  0.7107| F1:  0.5799 | Balanced Accuracy:  0.7777 | Dom Avg Accuracy:  0.7797 |
                Domain 1 Accuracy:  0.7560| Domain 1 F1:  0.7808 | Domain 1 Balanced Accuracy:  0.7560 | 
                Domain 2 Accuracy:  0.6933| Domain 2 F1:  0.4160 | Domain 2 Balanced Accuracy:  0.8035


100%|█████████▉| 1575/1576 [01:04<00:00, 24.51it/s]


Epoch 18 Train | Loss:  0.0519 | Accuracy:  0.8015| F1:  0.6600 | Balanced Accuracy:  0.8248 | Dom Avg Accuracy:  0.8355 |
                    Domain 1 Accuracy:  0.8411| Domain 1 F1:  0.8433 | Domain 1 Balanced Accuracy:  0.8411 | 
                    Domain 2 Accuracy:  0.7863| Domain 2 F1:  0.4891 | Domain 2 Balanced Accuracy:  0.8299
Epoch 18 Val   | Loss:  0.1729 | Accuracy:  0.7437| F1:  0.5831 | Balanced Accuracy:  0.7662 | Dom Avg Accuracy:  0.7745 |
                Domain 1 Accuracy:  0.7467| Domain 1 F1:  0.7507 | Domain 1 Balanced Accuracy:  0.7467 | 
                Domain 2 Accuracy:  0.7426| Domain 2 F1:  0.4410 | Domain 2 Balanced Accuracy:  0.8023


100%|█████████▉| 1575/1576 [01:04<00:00, 24.51it/s]


Epoch 19 Train | Loss:  0.0498 | Accuracy:  0.8131| F1:  0.6765 | Balanced Accuracy:  0.8367 | Dom Avg Accuracy:  0.8461 |
                    Domain 1 Accuracy:  0.8529| Domain 1 F1:  0.8555 | Domain 1 Balanced Accuracy:  0.8529 | 
                    Domain 2 Accuracy:  0.7978| Domain 2 F1:  0.5048 | Domain 2 Balanced Accuracy:  0.8393
Epoch 19 Val   | Loss:  0.1629 | Accuracy:  0.7585| F1:  0.6000 | Balanced Accuracy:  0.7787 | Dom Avg Accuracy:  0.7771 |
                Domain 1 Accuracy:  0.7533| Domain 1 F1:  0.7625 | Domain 1 Balanced Accuracy:  0.7533 | 
                Domain 2 Accuracy:  0.7605| Domain 2 F1:  0.4512 | Domain 2 Balanced Accuracy:  0.8009


100%|█████████▉| 1575/1576 [01:04<00:00, 24.49it/s]


Epoch 20 Train | Loss:  0.0484 | Accuracy:  0.8149| F1:  0.6791 | Balanced Accuracy:  0.8387 | Dom Avg Accuracy:  0.8491 |
                    Domain 1 Accuracy:  0.8546| Domain 1 F1:  0.8567 | Domain 1 Balanced Accuracy:  0.8546 | 
                    Domain 2 Accuracy:  0.7997| Domain 2 F1:  0.5093 | Domain 2 Balanced Accuracy:  0.8437
Epoch 20 Val   | Loss:  0.1793 | Accuracy:  0.6919| F1:  0.5689 | Balanced Accuracy:  0.7715 | Dom Avg Accuracy:  0.7703 |
                Domain 1 Accuracy:  0.7533| Domain 1 F1:  0.7846 | Domain 1 Balanced Accuracy:  0.7533 | 
                Domain 2 Accuracy:  0.6682| Domain 2 F1:  0.3959 | Domain 2 Balanced Accuracy:  0.7873


100%|█████████▉| 1575/1576 [01:04<00:00, 24.55it/s]


Epoch 21 Train | Loss:  0.0463 | Accuracy:  0.8238| F1:  0.6910 | Balanced Accuracy:  0.8462 | Dom Avg Accuracy:  0.8555 |
                    Domain 1 Accuracy:  0.8683| Domain 1 F1:  0.8704 | Domain 1 Balanced Accuracy:  0.8683 | 
                    Domain 2 Accuracy:  0.8067| Domain 2 F1:  0.5150 | Domain 2 Balanced Accuracy:  0.8427
Epoch 21 Val   | Loss:  0.1687 | Accuracy:  0.7459| F1:  0.5922 | Balanced Accuracy:  0.7760 | Dom Avg Accuracy:  0.7878 |
                Domain 1 Accuracy:  0.7760| Domain 1 F1:  0.7807 | Domain 1 Balanced Accuracy:  0.7760 | 
                Domain 2 Accuracy:  0.7344| Domain 2 F1:  0.4345 | Domain 2 Balanced Accuracy:  0.7996


100%|█████████▉| 1575/1576 [01:04<00:00, 24.52it/s]


Epoch 22 Train | Loss:  0.0438 | Accuracy:  0.8272| F1:  0.6975 | Balanced Accuracy:  0.8519 | Dom Avg Accuracy:  0.8624 |
                    Domain 1 Accuracy:  0.8783| Domain 1 F1:  0.8804 | Domain 1 Balanced Accuracy:  0.8783 | 
                    Domain 2 Accuracy:  0.8076| Domain 2 F1:  0.5183 | Domain 2 Balanced Accuracy:  0.8465
Epoch 22 Val   | Loss:  0.1759 | Accuracy:  0.7385| F1:  0.5905 | Balanced Accuracy:  0.7777 | Dom Avg Accuracy:  0.7873 |
                Domain 1 Accuracy:  0.7600| Domain 1 F1:  0.7698 | Domain 1 Balanced Accuracy:  0.7600 | 
                Domain 2 Accuracy:  0.7303| Domain 2 F1:  0.4416 | Domain 2 Balanced Accuracy:  0.8147


100%|█████████▉| 1575/1576 [01:04<00:00, 24.51it/s]


Epoch 23 Train | Loss:  0.0428 | Accuracy:  0.8307| F1:  0.7026 | Balanced Accuracy:  0.8555 | Dom Avg Accuracy:  0.8664 |
                    Domain 1 Accuracy:  0.8817| Domain 1 F1:  0.8836 | Domain 1 Balanced Accuracy:  0.8817 | 
                    Domain 2 Accuracy:  0.8111| Domain 2 F1:  0.5245 | Domain 2 Balanced Accuracy:  0.8510
Epoch 23 Val   | Loss:  0.1843 | Accuracy:  0.7774| F1:  0.6160 | Balanced Accuracy:  0.7867 | Dom Avg Accuracy:  0.7893 |
                Domain 1 Accuracy:  0.7707| Domain 1 F1:  0.7725 | Domain 1 Balanced Accuracy:  0.7707 | 
                Domain 2 Accuracy:  0.7800| Domain 2 F1:  0.4697 | Domain 2 Balanced Accuracy:  0.8080


100%|█████████▉| 1575/1576 [01:04<00:00, 24.31it/s]


Epoch 24 Train | Loss:  0.0406 | Accuracy:  0.8390| F1:  0.7132 | Balanced Accuracy:  0.8611 | Dom Avg Accuracy:  0.8719 |
                    Domain 1 Accuracy:  0.8877| Domain 1 F1:  0.8891 | Domain 1 Balanced Accuracy:  0.8877 | 
                    Domain 2 Accuracy:  0.8202| Domain 2 F1:  0.5368 | Domain 2 Balanced Accuracy:  0.8561
Epoch 24 Val   | Loss:  0.1700 | Accuracy:  0.7911| F1:  0.6250 | Balanced Accuracy:  0.7883 | Dom Avg Accuracy:  0.7846 |
                Domain 1 Accuracy:  0.7720| Domain 1 F1:  0.7729 | Domain 1 Balanced Accuracy:  0.7720 | 
                Domain 2 Accuracy:  0.7985| Domain 2 F1:  0.4767 | Domain 2 Balanced Accuracy:  0.7972


100%|█████████▉| 1575/1576 [01:04<00:00, 24.50it/s]


Epoch 25 Train | Loss:  0.0379 | Accuracy:  0.8463| F1:  0.7254 | Balanced Accuracy:  0.8704 | Dom Avg Accuracy:  0.8811 |
                    Domain 1 Accuracy:  0.9034| Domain 1 F1:  0.9050 | Domain 1 Balanced Accuracy:  0.9034 | 
                    Domain 2 Accuracy:  0.8243| Domain 2 F1:  0.5428 | Domain 2 Balanced Accuracy:  0.8589
Epoch 25 Val   | Loss:  0.1830 | Accuracy:  0.7489| F1:  0.5891 | Balanced Accuracy:  0.7707 | Dom Avg Accuracy:  0.7707 |
                Domain 1 Accuracy:  0.7467| Domain 1 F1:  0.7558 | Domain 1 Balanced Accuracy:  0.7467 | 
                Domain 2 Accuracy:  0.7497| Domain 2 F1:  0.4404 | Domain 2 Balanced Accuracy:  0.7948


100%|█████████▉| 1575/1576 [01:04<00:00, 24.55it/s]


Epoch 26 Train | Loss:  0.0373 | Accuracy:  0.8544| F1:  0.7369 | Balanced Accuracy:  0.8770 | Dom Avg Accuracy:  0.8878 |
                    Domain 1 Accuracy:  0.9086| Domain 1 F1:  0.9098 | Domain 1 Balanced Accuracy:  0.9086 | 
                    Domain 2 Accuracy:  0.8335| Domain 2 F1:  0.5579 | Domain 2 Balanced Accuracy:  0.8670
Epoch 26 Val   | Loss:  0.1916 | Accuracy:  0.7333| F1:  0.5804 | Balanced Accuracy:  0.7679 | Dom Avg Accuracy:  0.7759 |
                Domain 1 Accuracy:  0.7507| Domain 1 F1:  0.7599 | Domain 1 Balanced Accuracy:  0.7507 | 
                Domain 2 Accuracy:  0.7267| Domain 2 F1:  0.4312 | Domain 2 Balanced Accuracy:  0.8011

Epoch 24 Val   | Loss:  0.1700 | Accuracy:  0.7911| F1:  0.6250 | Balanced Accuracy:  0.7883 | Dom Avg Accuracy:  0.7846 |
                Domain 1 Accuracy:  0.7720| Domain 1 F1:  0.7729 | Domain 1 Balanced Accuracy:  0.7720 | 
                Domain 2 Accuracy:  0.7985| Domain 2 F1:  0.4767 | Domain 2 Balanced Accuracy:  0.

---
# Hinge Loss

In [7]:
# BERT - hinge work in progress

MAX_SENTENCE_LENGTH = 256
MIN_FREQUENCY = 40 # because 40 is statistical sample requirement
MAKE_CROPPED_REMAINS_INTO_NEW_INSTANCE = False
LOW_FREQ_TOKEN = False
CLS = True
PAD_FRONT = False
W2V_CONTEXT_WINDOW = 5 # 2 to left, 2 to right
cropped_train_data = crop_sentence_length(train_data, max_sentence_length = MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = MAKE_CROPPED_REMAINS_INTO_NEW_INSTANCE)
cropped_val_data = crop_sentence_length(val_data, max_sentence_length =  MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = False)
cropped_test_data = crop_sentence_length(test_data, max_sentence_length = MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = False)
cropped_future_data = crop_sentence_length(data_test, max_sentence_length = MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = False)
raw_token_pytorch_map = get_raw_token_pytorch_map(data = cropped_train_data, min_frequency = MIN_FREQUENCY)
train_x, train_y, val_x, val_y, test_x, test_y, train_dom, val_dom, test_dom, future_x = Data_Factory(cropped_train_data, \
                                                              cropped_val_data, \
                                                                cropped_test_data, \
                                                                    cropped_future_data, \
                                                                        MAX_SENTENCE_LENGTH, \
                                                                            raw_token_pytorch_map, \
                                                                                CLS=CLS, \
                                                                                    low_freq_special_token=LOW_FREQ_TOKEN, \
                                                                                        pad_front=PAD_FRONT)

# train_x = train_x[:8]
# train_y = train_y[:8]
# train_dom = train_dom[:8]
# val_dom = val_dom[:8]
# val_x = val_x[:8]
# val_y = val_y[:8]

pos_prior, neg_prior = get_distribution(train_y)
print('class prior:', pos_prior, neg_prior)
pos_dom_prior, neg_dom_prior = get_distribution(train_dom)
print('domain prior:', pos_dom_prior, neg_dom_prior)
dom1_pos_prior, dom1_neg_prior = get_distribution([[0, 1] if label == 1 else [1, 0] for label in label1])
print('dom1 class prior:', dom1_pos_prior, dom1_neg_prior)
dom2_pos_prior, dom2_neg_prior = get_distribution([[0, 1] if label == 1 else [1, 0] for label in label2])
print('dom2 class prior:', dom2_pos_prior, dom2_neg_prior)
pretrain_x, pretrain_y, pretrain_mask, pretrain_dom, preval_x, preval_y, preval_mask, preval_dom = BERT_pretrain_DataFactory(train_data, val_data, SEED, raw_token_pytorch_map, MAX_SENTENCE_LENGTH)

print('---')

class BERT_Hinge_config:
    # ----------------- architectual hyperparameters ----------------- #
    d_model = 256
    d_ff = 1024 # = 4* d_model
    n_heads = 8
    dropout = 0.1
    e_layers = 6
    embedding_aggregation = 'cls' # TODO
    n_mlp_layers = 1
    # n_mlp_dom_layers = 1
    res_learning = False
    activation = nn.ReLU()
    mask_flag = False # causal mask
    train_embedding = True
    # ----------------- optimisation hyperparameters ----------------- #
    random_state = SEED
    batch_size = 8
    epochs = 32
    lr = 1e-3
    patience = 10
    # loss = nn.BCELoss()
    # loss = nn.BCELoss(weight=torch.FloatTensor([pos_prior, neg_prior]))
    loss = HingeLoss()
    # validation_loss = nn.BCELoss()
    # validation_loss = nn.BCELoss(weight=torch.FloatTensor([pos_prior, neg_prior]))
    validation_loss = HingeLoss()
    domain_loss = nn.BCELoss()
    pretrain_loss = nn.CrossEntropyLoss()
    pretrain_validation_loss = nn.CrossEntropyLoss()
    alpha = 0
    gradient_reversal_every_n_epoch = 1
    regularisation_loss = None
    scheduler = False
    grad_clip = False
    # ----------------- operation hyperparameters ----------------- #
    d_output = 1
    seq_len = MAX_SENTENCE_LENGTH
    n_unique_tokens = len(raw_token_pytorch_map)
    # ----------------- saving hyperparameters ----------------- #
    rootpath = home_directory + './'
    saving_address = home_directory + f'./results/'
    name = f'BERT_Hinge'

model = BERT_Hinge(BERT_Hinge_config) # initialise the model

# train the model (all cells except this one will print training log and evaluation at each batch)
best_epoch = model.fit(train_x, train_y, train_dom, val_x, val_y, val_dom)
print()

# as model automatically saves best epoch, will now load the best epoch and evaluate on test set
model.load()
model.eval(val_x, val_y, val_dom, best_epoch, evaluation_mode = True)
model.eval(test_x, test_y, test_dom, best_epoch, evaluation_mode = True)

100%|██████████| 12600/12600 [00:00<00:00, 22925.52it/s]
100%|██████████| 2700/2700 [00:00<00:00, 31301.90it/s]
100%|██████████| 2700/2700 [00:00<00:00, 8854.81it/s] 
100%|██████████| 4000/4000 [00:00<00:00, 19476.57it/s]
100%|██████████| 12600/12600 [00:01<00:00, 7186.27it/s]
100%|██████████| 12600/12600 [00:00<00:00, 43106.40it/s]
100%|██████████| 2700/2700 [00:00<00:00, 47805.40it/s]
100%|██████████| 2700/2700 [00:00<00:00, 33598.43it/s]
100%|██████████| 4000/4000 [00:00<00:00, 29364.68it/s]


class prior: 0.75 0.25
domain prior: 0.0 1.0
dom1 class prior: 0.5 0.5
dom2 class prior: 0.11538461538461539 0.8846153846153846


100%|██████████| 12600/12600 [00:02<00:00, 5948.11it/s]
100%|██████████| 2700/2700 [00:00<00:00, 4334.59it/s]


---


  from .autonotebook import tqdm as notebook_tqdm
 50%|█████     | 1/2 [00:01<00:01,  1.79s/it]
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 1 Train | Loss:  0.9633 | Accuracy:  0.7500| F1:  0.8571 | Balanced Accuracy:  0.5000 | Dom Avg Accuracy:     nan |
                    Domain 1 Accuracy:  0.7500| Domain 1 F1:  0.8571 | Domain 1 Balanced Accuracy:  0.5000 | 
                    Domain 2 Accuracy:     nan| Domain 2 F1:  0.0000 | Domain 2 Balanced Accuracy:     nan


  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 1 Val   | Loss:  1.1302 | Accuracy:  0.3750| F1:  0.5455 | Balanced Accuracy:  0.5000 | Dom Avg Accuracy:     nan |
                Domain 1 Accuracy:  0.3750| Domain 1 F1:  0.5455 | Domain 1 Balanced Accuracy:  0.5000 | 
                Domain 2 Accuracy:     nan| Domain 2 F1:  0.0000 | Domain 2 Balanced Accuracy:     nan


 50%|█████     | 1/2 [00:01<00:01,  1.43s/it]
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 2 Train | Loss:  0.7216 | Accuracy:  0.7500| F1:  0.8571 | Balanced Accuracy:  0.5000 | Dom Avg Accuracy:     nan |
                    Domain 1 Accuracy:  0.7500| Domain 1 F1:  0.8571 | Domain 1 Balanced Accuracy:  0.5000 | 
                    Domain 2 Accuracy:     nan| Domain 2 F1:  0.0000 | Domain 2 Balanced Accuracy:     nan


  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 2 Val   | Loss:  1.4460 | Accuracy:  0.3750| F1:  0.5455 | Balanced Accuracy:  0.5000 | Dom Avg Accuracy:     nan |
                Domain 1 Accuracy:  0.3750| Domain 1 F1:  0.5455 | Domain 1 Balanced Accuracy:  0.5000 | 
                Domain 2 Accuracy:     nan| Domain 2 F1:  0.0000 | Domain 2 Balanced Accuracy:     nan


 50%|█████     | 1/2 [00:02<00:02,  2.20s/it]
  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 3 Train | Loss:  0.5789 | Accuracy:  0.7500| F1:  0.8571 | Balanced Accuracy:  0.5000 | Dom Avg Accuracy:     nan |
                    Domain 1 Accuracy:  0.7500| Domain 1 F1:  0.8571 | Domain 1 Balanced Accuracy:  0.5000 | 
                    Domain 2 Accuracy:     nan| Domain 2 F1:  0.0000 | Domain 2 Balanced Accuracy:     nan


  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 3 Val   | Loss:  1.2411 | Accuracy:  0.3750| F1:  0.5455 | Balanced Accuracy:  0.5000 | Dom Avg Accuracy:     nan |
                Domain 1 Accuracy:  0.3750| Domain 1 F1:  0.5455 | Domain 1 Balanced Accuracy:  0.5000 | 
                Domain 2 Accuracy:     nan| Domain 2 F1:  0.0000 | Domain 2 Balanced Accuracy:     nan


  0%|          | 0/2 [00:02<?, ?it/s]


KeyboardInterrupt: 

---

# Historic Best

In [None]:
# BERT - SoftmaxBCELoss+OptBalAccu

MAX_SENTENCE_LENGTH = 256
MIN_FREQUENCY = 40 # because 40 is statistical sample requirement
MAKE_CROPPED_REMAINS_INTO_NEW_INSTANCE = False
LOW_FREQ_TOKEN = False
CLS = True
PAD_FRONT = False
W2V_CONTEXT_WINDOW = 5 # 2 to left, 2 to right
cropped_train_data = crop_sentence_length(train_data, max_sentence_length = MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = MAKE_CROPPED_REMAINS_INTO_NEW_INSTANCE)
cropped_val_data = crop_sentence_length(val_data, max_sentence_length =  MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = False)
cropped_test_data = crop_sentence_length(test_data, max_sentence_length = MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = False)
cropped_future_data = crop_sentence_length(data_test, max_sentence_length = MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = False)
raw_token_pytorch_map = get_raw_token_pytorch_map(data = cropped_train_data, min_frequency = MIN_FREQUENCY)
train_x, train_y, val_x, val_y, test_x, test_y, train_dom, val_dom, test_dom, future_x = Data_Factory(cropped_train_data, \
                                                              cropped_val_data, \
                                                                cropped_test_data, \
                                                                    cropped_future_data, \
                                                                        MAX_SENTENCE_LENGTH, \
                                                                            raw_token_pytorch_map, \
                                                                                CLS=CLS, \
                                                                                    low_freq_special_token=LOW_FREQ_TOKEN, \
                                                                                        pad_front=PAD_FRONT)
pos_prior, neg_prior = get_distribution(train_y)
print('class prior:', pos_prior, neg_prior)
pos_dom_prior, neg_dom_prior = get_distribution(train_dom)
print('domain prior:', pos_dom_prior, neg_dom_prior)
dom1_pos_prior, dom1_neg_prior = get_distribution([[0, 1] if label == 1 else [1, 0] for label in label1])
print('dom1 class prior:', dom1_pos_prior, dom1_neg_prior)
dom2_pos_prior, dom2_neg_prior = get_distribution([[0, 1] if label == 1 else [1, 0] for label in label2])
print('dom2 class prior:', dom2_pos_prior, dom2_neg_prior)
pretrain_x, pretrain_y, pretrain_mask, pretrain_dom, preval_x, preval_y, preval_mask, preval_dom = BERT_pretrain_DataFactory(train_data, val_data, SEED, raw_token_pytorch_map, MAX_SENTENCE_LENGTH)

print('---')

class BERT_DANN_config:
    # ----------------- architectual hyperparameters ----------------- #
    d_model = 256
    d_ff = 1024 # = 4* d_model
    n_heads = 8
    dropout = 0.1
    e_layers = 6 # actually 8 was better
    embedding_aggregation = 'cls' # TODO
    n_mlp_clf_layers = 1
    n_mlp_dom_layers = 1
    res_learning = False
    activation = nn.ReLU()
    mask_flag = False # causal mask
    train_embedding = True
    # ----------------- optimisation hyperparameters ----------------- #
    random_state = SEED
    batch_size = 8
    epochs = 32
    lr = 1e-5
    patience = 10
    # loss = nn.BCELoss()
    loss = nn.BCELoss(weight=torch.FloatTensor([pos_prior, neg_prior]))
    # validation_loss = nn.BCELoss()
    validation_loss = nn.BCELoss(weight=torch.FloatTensor([pos_prior, neg_prior]))
    domain_loss = nn.BCELoss()
    pretrain_loss = nn.CrossEntropyLoss()
    pretrain_validation_loss = nn.CrossEntropyLoss()
    alpha = 0.1
    gradient_reversal_every_n_epoch = 1
    regularisation_loss = None
    scheduler = False
    grad_clip = False
    # ----------------- operation hyperparameters ----------------- #
    d_output = 2
    seq_len = MAX_SENTENCE_LENGTH
    n_unique_tokens = len(raw_token_pytorch_map)
    # ----------------- saving hyperparameters ----------------- #
    rootpath = home_directory + './'
    saving_address = home_directory + f'./results/'
    name = f'BERT_DANN'

model = BERT_DANN(BERT_DANN_config) # initialise the model

# train the model (all cells except this one will print training log and evaluation at each batch)
best_epoch = model.fit(train_x, train_y, train_dom, val_x, val_y, val_dom)
print()

# as model automatically saves best epoch, will now load the best epoch and evaluate on test set
model.load()
model.eval(val_x, val_y, val_dom, best_epoch, evaluation_mode = True)
model.eval(test_x, test_y, test_dom, best_epoch, evaluation_mode = True)

In [None]:
# BERT - with low freq token - SoftmaxBCELoss+OptBalAccu

MAX_SENTENCE_LENGTH = 256
MIN_FREQUENCY = 40 # because 40 is statistical sample requirement
MAKE_CROPPED_REMAINS_INTO_NEW_INSTANCE = False
LOW_FREQ_TOKEN = True
CLS = True
PAD_FRONT = False
W2V_CONTEXT_WINDOW = 5 # 2 to left, 2 to right
cropped_train_data = crop_sentence_length(train_data, max_sentence_length = MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = MAKE_CROPPED_REMAINS_INTO_NEW_INSTANCE)
cropped_val_data = crop_sentence_length(val_data, max_sentence_length =  MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = False)
cropped_test_data = crop_sentence_length(test_data, max_sentence_length = MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = False)
cropped_future_data = crop_sentence_length(data_test, max_sentence_length = MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = False)
raw_token_pytorch_map = get_raw_token_pytorch_map(data = cropped_train_data, min_frequency = MIN_FREQUENCY)
train_x, train_y, val_x, val_y, test_x, test_y, train_dom, val_dom, test_dom, future_x = Data_Factory(cropped_train_data, \
                                                              cropped_val_data, \
                                                                cropped_test_data, \
                                                                    cropped_future_data, \
                                                                        MAX_SENTENCE_LENGTH, \
                                                                            raw_token_pytorch_map, \
                                                                                CLS=CLS, \
                                                                                    low_freq_special_token=LOW_FREQ_TOKEN, \
                                                                                        pad_front=PAD_FRONT)
pos_prior, neg_prior = get_distribution(train_y)
print('class prior:', pos_prior, neg_prior)
pos_dom_prior, neg_dom_prior = get_distribution(train_dom)
print('domain prior:', pos_dom_prior, neg_dom_prior)
dom1_pos_prior, dom1_neg_prior = get_distribution([[0, 1] if label == 1 else [1, 0] for label in label1])
print('dom1 class prior:', dom1_pos_prior, dom1_neg_prior)
dom2_pos_prior, dom2_neg_prior = get_distribution([[0, 1] if label == 1 else [1, 0] for label in label2])
print('dom2 class prior:', dom2_pos_prior, dom2_neg_prior)
pretrain_x, pretrain_y, pretrain_mask, pretrain_dom, preval_x, preval_y, preval_mask, preval_dom = BERT_pretrain_DataFactory(train_data, val_data, SEED, raw_token_pytorch_map, MAX_SENTENCE_LENGTH)

print('---')

class BERT_DANN_config:
    # ----------------- architectual hyperparameters ----------------- #
    d_model = 256
    d_ff = 1024 # = 4* d_model
    n_heads = 8
    dropout = 0.1
    e_layers = 6
    embedding_aggregation = 'cls' # TODO
    n_mlp_clf_layers = 1
    n_mlp_dom_layers = 1
    res_learning = False
    activation = nn.ReLU()
    mask_flag = False # causal mask
    train_embedding = True
    # ----------------- optimisation hyperparameters ----------------- #
    random_state = SEED
    batch_size = 8
    epochs = 32
    lr = 1e-5
    patience = 10
    # loss = nn.BCELoss()
    loss = nn.BCELoss(weight=torch.FloatTensor([pos_prior, neg_prior]))
    # validation_loss = nn.BCELoss()
    validation_loss = nn.BCELoss(weight=torch.FloatTensor([pos_prior, neg_prior]))
    domain_loss = nn.BCELoss()
    pretrain_loss = nn.CrossEntropyLoss()
    pretrain_validation_loss = nn.CrossEntropyLoss()
    alpha = 0.1
    gradient_reversal_every_n_epoch = 1
    regularisation_loss = None
    scheduler = False
    grad_clip = False
    # ----------------- operation hyperparameters ----------------- #
    d_output = 2
    seq_len = MAX_SENTENCE_LENGTH
    n_unique_tokens = len(raw_token_pytorch_map)
    # ----------------- saving hyperparameters ----------------- #
    rootpath = home_directory + './'
    saving_address = home_directory + f'./results/'
    name = f'BERT_DANN'

model = BERT_DANN(BERT_DANN_config) # initialise the model

# train the model (all cells except this one will print training log and evaluation at each batch)
best_epoch = model.fit(train_x, train_y, train_dom, val_x, val_y, val_dom)
print()

# as model automatically saves best epoch, will now load the best epoch and evaluate on test set
model.load()
model.eval(val_x, val_y, val_dom, best_epoch, evaluation_mode = True)
model.eval(test_x, test_y, test_dom, best_epoch, evaluation_mode = True)

---
---
# LSTM Graveyard

In [None]:
# LSTM
MAX_SENTENCE_LENGTH = 64
MIN_FREQUENCY = 0 # because 40 is statistical sample requirement
MAKE_CROPPED_REMAINS_INTO_NEW_INSTANCE = False
LOW_FREQ_TOKEN = False
CLS = False
PAD_FRONT = True
W2V_CONTEXT_WINDOW = 5 # 2 to left, 2 to right
cropped_train_data = crop_sentence_length(train_data, max_sentence_length = MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = MAKE_CROPPED_REMAINS_INTO_NEW_INSTANCE)
cropped_val_data = crop_sentence_length(val_data, max_sentence_length =  MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = False)
cropped_test_data = crop_sentence_length(test_data, max_sentence_length = MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = False)
cropped_future_data = crop_sentence_length(data_test, max_sentence_length = MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = False)
raw_token_pytorch_map = get_raw_token_pytorch_map(data = cropped_train_data, min_frequency = MIN_FREQUENCY)
train_x, train_y, val_x, val_y, test_x, test_y, train_dom, val_dom, test_dom, future_x = Data_Factory(cropped_train_data, \
                                                              cropped_val_data, \
                                                                cropped_test_data, \
                                                                    cropped_future_data, \
                                                                        MAX_SENTENCE_LENGTH, \
                                                                            raw_token_pytorch_map, \
                                                                                CLS=CLS, \
                                                                                    low_freq_special_token=LOW_FREQ_TOKEN, \
                                                                                        pad_front=PAD_FRONT)
pos_prior, neg_prior = get_distribution(train_y)
print('class prior:', pos_prior, neg_prior)
pos_dom_prior, neg_dom_prior = get_distribution(train_dom)
print('domain prior:', pos_dom_prior, neg_dom_prior)
dom1_pos_prior, dom1_neg_prior = get_distribution([[0, 1] if label == 1 else [1, 0] for label in label1])
print('dom1 class prior:', dom1_pos_prior, dom1_neg_prior)
dom2_pos_prior, dom2_neg_prior = get_distribution([[0, 1] if label == 1 else [1, 0] for label in label2])
print('dom2 class prior:', dom2_pos_prior, dom2_neg_prior)
pretrain_x, pretrain_y, pretrain_mask, pretrain_dom, preval_x, preval_y, preval_mask, preval_dom = BERT_pretrain_DataFactory(train_data, val_data, SEED, raw_token_pytorch_map, MAX_SENTENCE_LENGTH)

print('---')

class LSTM_config:
    # ----------------- architectual hyperparameters ----------------- #
    d_model = 256
    n_recurrent_layers = 2
    bidirectional = False
    n_heads = 8
    dropout = 0.1
    n_mlp_layers = 1
    flatten = False
    activation = nn.ReLU()
    res_learning = True
    mask_flag = False # TODO
    train_embedding = True
    # ----------------- optimisation hyperparameters ----------------- #
    random_state = SEED
    batch_size = 8
    epochs = 32
    lr = 1e-5
    patience = 4
    loss = nn.BCELoss()
    # loss = nn.BCELoss(weight=torch.FloatTensor([pos_prior, neg_prior]))
    validation_loss = nn.BCELoss()
    # validation_loss = nn.BCELoss(weight=torch.FloatTensor([pos_prior, neg_prior]))
    pretrain_loss = nn.CrossEntropyLoss()
    pretrain_validation_loss = nn.CrossEntropyLoss()
    regularisation_loss = None
    scheduler = True
    grad_clip = False
    # ----------------- operation hyperparameters ----------------- #
    d_output = 2
    seq_len = MAX_SENTENCE_LENGTH
    n_unique_tokens = len(raw_token_pytorch_map)
    # ----------------- saving hyperparameters ----------------- #
    rootpath = home_directory + './'
    saving_address = home_directory +  f'./results/'
    name = f'LSTM_Classifier'



model = LSTM(LSTM_config) # initialise the model

# train the model (all cells except this one will print training log and evaluation at each batch)
best_epoch = model.fit(train_x, train_y, train_dom, val_x, val_y, val_dom)
print()

# as model automatically saves best epoch, will now load the best epoch and evaluate on test set
model.load()
model.eval(val_x, val_y, val_dom, best_epoch, evaluation_mode = True)
model.eval(test_x, test_y, test_dom, best_epoch, evaluation_mode = True)

In [None]:
# LSTM_DANN

MAX_SENTENCE_LENGTH = 128
MIN_FREQUENCY = 0 # because 40 is statistical sample requirement
MAKE_CROPPED_REMAINS_INTO_NEW_INSTANCE = False
LOW_FREQ_TOKEN = False
CLS = True
PAD_FRONT = False
W2V_CONTEXT_WINDOW = 5 # 2 to left, 2 to right
cropped_train_data = crop_sentence_length(train_data, max_sentence_length = MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = MAKE_CROPPED_REMAINS_INTO_NEW_INSTANCE)
cropped_val_data = crop_sentence_length(val_data, max_sentence_length =  MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = False)
cropped_test_data = crop_sentence_length(test_data, max_sentence_length = MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = False)
cropped_future_data = crop_sentence_length(data_test, max_sentence_length = MAX_SENTENCE_LENGTH, make_cropped_remains_into_new_instance = False)
raw_token_pytorch_map = get_raw_token_pytorch_map(data = cropped_train_data, min_frequency = MIN_FREQUENCY)
train_x, train_y, val_x, val_y, test_x, test_y, train_dom, val_dom, test_dom, future_x = Data_Factory(cropped_train_data, \
                                                              cropped_val_data, \
                                                                cropped_test_data, \
                                                                    cropped_future_data, \
                                                                        MAX_SENTENCE_LENGTH, \
                                                                            raw_token_pytorch_map, \
                                                                                CLS=CLS, \
                                                                                    low_freq_special_token=LOW_FREQ_TOKEN, \
                                                                                        pad_front=PAD_FRONT)
pos_prior, neg_prior = get_distribution(train_y)
print('class prior:', pos_prior, neg_prior)
pos_dom_prior, neg_dom_prior = get_distribution(train_dom)
print('domain prior:', pos_dom_prior, neg_dom_prior)
dom1_pos_prior, dom1_neg_prior = get_distribution([[0, 1] if label == 1 else [1, 0] for label in label1])
print('dom1 class prior:', dom1_pos_prior, dom1_neg_prior)
dom2_pos_prior, dom2_neg_prior = get_distribution([[0, 1] if label == 1 else [1, 0] for label in label2])
print('dom2 class prior:', dom2_pos_prior, dom2_neg_prior)
pretrain_x, pretrain_y, pretrain_mask, pretrain_dom, preval_x, preval_y, preval_mask, preval_dom = BERT_pretrain_DataFactory(train_data, val_data, SEED, raw_token_pytorch_map, MAX_SENTENCE_LENGTH)

print('---')

class LSTM_DANN_config:
    # ----------------- architectual hyperparameters ----------------- #
    d_model = 512
    n_recurrent_layers = 2
    bidirectional = False
    n_heads = 8
    dropout = 0.1
    n_mlp_clf_layers = 1
    n_mlp_dom_layers = 1
    flatten = False
    activation = nn.ReLU()
    res_learning = True
    mask_flag = False # TODO
    train_embedding = True
    # ----------------- optimisation hyperparameters ----------------- #
    random_state = SEED
    batch_size = 128
    epochs = 32
    lr = 1e-5
    patience = 5
    loss = nn.BCELoss()
    # loss = nn.BCELoss(weight=torch.FloatTensor([pos_prior, neg_prior]))
    validation_loss = nn.BCELoss()
    # validation_loss = nn.BCELoss(weight=torch.FloatTensor([pos_prior, neg_prior]))
    domain_loss = nn.BCELoss()
    pretrain_loss = nn.CrossEntropyLoss()
    pretrain_validation_loss = nn.CrossEntropyLoss()
    alpha = 0.1
    regularisation_loss = None
    scheduler = True
    grad_clip = False
    # ----------------- operation hyperparameters ----------------- #
    d_output = 2
    seq_len = MAX_SENTENCE_LENGTH
    n_unique_tokens = len(raw_token_pytorch_map)
    # ----------------- saving hyperparameters ----------------- #
    rootpath = home_directory + './'
    saving_address = home_directory +  f'./results/'
    name = f'LSTM_DANN'

model = LSTM_DANN(LSTM_DANN_config) # initialise the model

# train the model (all cells except this one will print training log and evaluation at each batch)
best_epoch = model.fit(train_x, train_y, train_dom, val_x, val_y, val_dom)
print()

# as model automatically saves best epoch, will now load the best epoch and evaluate on test set
model.load()
model.eval(val_x, val_y, val_dom, best_epoch, evaluation_mode = True)
model.eval(test_x, test_y, test_dom, best_epoch, evaluation_mode = True)