In [1]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import sys, os, math
import torch
import json

device = 'cuda:1' if torch.cuda.is_available() else 'cpu'
print(device)

epochs = 100_000
val_epoch = 1000
num_val = 100
batch_size = 128
virus_dataset_name = "corpus_1000_Viruses"
cellular_dataset_name = "corpus_1000_cellular"
model_name = "ML"
max_seq_len = 1000

sys.path.insert(0, '../dlp')
from data_access import PQDataAccess
virus_da = PQDataAccess(f"/home/aac/Alireza/datasets/export_pqt_4_taxseq_new/{virus_dataset_name}", batch_size)
cellular_da = PQDataAccess(f"/home/aac/Alireza/datasets/export_pqt_4_taxseq_new/{cellular_dataset_name}", batch_size)

cuda:1
 WORLD_SIZE=1 , LOCAL_WORLD_SIZE=1,RANK =0,LOCAL_RANK = 0 


In [2]:
import random
index2name_file = "../data/taxonomy_index.json"

# Check if the file exists
if os.path.exists(index2name_file):
    with open(index2name_file, "rb") as f:
        index2name = json.load(f)

tax_vocab_sizes = {
    int(k): len(v) for k,v in index2name.items()
}

level_encoder = {
    int(k): {name: idx + 1 for idx, name in enumerate(v)} for k,v in index2name.items()
}

level_decoder = {
    int(k): {idx + 1: name for idx, name in enumerate(v)} for k,v in index2name.items()
}

for k, v in level_decoder.items():
    level_decoder[k][0] = "NOT DEFINED"


def encode_lineage(lineage_str):
    taxes_str = lineage_str.split(", ")

    encoded = {int(k): 0 for k in index2name.keys()}
    
    for i, tax_str in enumerate(taxes_str):
        encoded[i] = level_encoder[i][tax_str]

    return encoded


from sklearn.preprocessing import LabelEncoder, OneHotEncoder


def one_hot_encode_sequence(seq):
    amino_acids = '-ACDEFGHIKLMNPQRSTVWY'  # 20 standard amino acids
    encoder = OneHotEncoder(categories=[list(amino_acids)], sparse_output=False)

    # Convert sequence into a 2D array (each amino acid as a separate row)
    sequence_array = np.array([aa if aa in amino_acids else '-' for aa in seq] + ['-' for _ in range(max_seq_len - len(seq))]).reshape(-1, 1)

    # Perform one-hot encoding
    one_hot_encoded = encoder.fit_transform(sequence_array)

    # Flatten the one-hot encoded array to a single vector for each sequence
    return one_hot_encoded.flatten()

def mix_data_to_tensor_batch(b_virues, b_cellular, max_seq_len=max_seq_len, partition=0.25):
    split_point = int(len(b_virues) * partition)
    b = b_virues[:split_point] + b_cellular[-len(b_virues) + split_point:]
    random.shuffle(b)  # In-place shuffle
    
    inputs = np.array([one_hot_encode_sequence(e['Sequence']) for e in b])

    tax_ids = [encode_lineage(e['Taxonomic_lineage__ALL_']) for e in b]
    combined_dict = {}
    for d in tax_ids:
        for key, value in d.items():
            combined_dict.setdefault(key, []).append(value)

    encoded = {k: np.array(v) for k, v in combined_dict.items()}
    return {"sequences": inputs, "labels":encoded}

def val_loader():
    val_batches = [virus_da.get_batch() for _ in range(num_val // 2)] + [cellular_da.get_batch() for _ in range(num_val // 2)]

    inputs = np.array([one_hot_encode_sequence(e['Sequence']) for b in val_batches for e in b])

    tax_ids = [encode_lineage(e['Taxonomic_lineage__ALL_']) for b in val_batches for e in b]
    combined_dict = {}
    for d in tax_ids:
        for key, value in d.items():
            combined_dict.setdefault(key, []).append(value)

    encoded = {k: np.array(v) for k, v in combined_dict.items()}
    return {"sequences": inputs, "labels":encoded}

In [3]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix, \
    classification_report, confusion_matrix
# import xgboost as xgb
#
# # Set up the DMatrix for XGBoost (with GPU)
# dtrain = xgb.DMatrix(X_train, label=y_train)
# dtest = xgb.DMatrix(X_test, label=y_test)
#
# # Set GPU parameters
# params = {
#     'tree_method': 'gpu_hist',  # Use GPU acceleration
#     'device': 'cuda:0',
#     'objective': 'binary:logistic',
#     'eval_metric': 'logloss',
#     'max_depth': 6
# }
#
# # Train the model
# bst = xgb.train(params, dtrain, num_boost_round=100)
#
# # Predict and evaluate
# y_pred = (bst.predict(dtest) > 0.5).astype(int)
# y_proba = bst.predict(dtest)

In [4]:
def get_partition_ratio(epoch, decay_epochs=100000):
    """
    Calculate partition ratio that decreases from 8/16 to 1/16 in steps
    """
    # Calculate how many epochs before each step down
    epochs_per_step = decay_epochs // 7  # 7 steps from 8/16 down to 1/16
    
    # Calculate current step based on epoch
    step = min(epoch // epochs_per_step, 7)  # Max 7 steps down from 8
    
    # Map step to fraction
    fraction = (8 - step) / 16
    
    return fraction

In [None]:
val_inputs = val_loader()


def evaluate(clf):
    y_test = val_inputs["labels"][0]
    y_pred = clf.predict(val_inputs["sequences"])
    y_proba = clf.predict_proba(val_inputs["sequences"])[:, 1]  # Probabilities for ROC-AUC
    
    # Calculate metrics
    accuracy = accuracy_score(y_test, y_pred)
    precision = precision_score(y_test, y_pred)
    recall = recall_score(y_test, y_pred)
    f1 = f1_score(y_test, y_pred)
    roc_auc = roc_auc_score(y_test, y_proba)
    conf_matrix = confusion_matrix(y_test, y_pred)
    
    print(f"Accuracy:{accuracy:.4f}")
    print(f"Precision:{precision:.4f}")
    print(f"Recall:{recall:.4f}")
    print(f"F1-Score:{f1:.4f}")
    print(f"ROC-AUC:{roc_auc:.4f}")
    print("\nConfusion Matrix:\n", conf_matrix)
    print("\nClassification Report:\n", classification_report(y_test, y_pred))
    
def get_partition_ratio(epoch, decay_epochs=100000):
    """
    Calculate partition ratio that decreases from 8/16 to 1/16 in steps
    """
    # Calculate how many epochs before each step down
    epochs_per_step = decay_epochs // 7  # 7 steps from 8/16 down to 1/16
    
    # Calculate current step based on epoch
    step = min(epoch // epochs_per_step, 7)  # Max 7 steps down from 8
    
    # Map step to fraction
    fraction = (8 - step) / 16
    
    return fraction


clf = RandomForestClassifier(n_estimators=200, random_state=42)
for epoch in tqdm(range(epochs)):
    inputs = mix_data_to_tensor_batch(virus_da.get_batch(), cellular_da.get_batch(), partition = get_partition_ratio(epoch+1))
    sequences = inputs["sequences"]
    labels = inputs["labels"][0]
    clf.fit(sequences, labels)
    
    if (epoch + 1) % val_epoch == 0:
        print(f"Epoch [{epoch + 1}/{epochs}]")
        evaluate(clf)

  1%|          | 999/100000 [06:25<11:25:35,  2.41it/s]

Epoch [1000/100000]


  1%|          | 1000/100000 [06:27<21:34:08,  1.27it/s]

Accuracy:0.6062
Precision:0.6134
Recall:0.5745
F1-Score:0.5934
ROC-AUC:0.6670

Confusion Matrix:
 [[3677 2723]
 [2317 4083]]

Classification Report:
               precision    recall  f1-score   support

           1       0.61      0.57      0.59      6400
           2       0.60      0.64      0.62      6400

    accuracy                           0.61     12800
   macro avg       0.61      0.61      0.61     12800
weighted avg       0.61      0.61      0.61     12800



  2%|▏         | 1999/100000 [13:01<10:41:52,  2.54it/s]

Epoch [2000/100000]


  2%|▏         | 2000/100000 [13:03<20:46:12,  1.31it/s]

Accuracy:0.6344
Precision:0.6032
Recall:0.7856
F1-Score:0.6824
ROC-AUC:0.6892

Confusion Matrix:
 [[5028 1372]
 [3308 3092]]

Classification Report:
               precision    recall  f1-score   support

           1       0.60      0.79      0.68      6400
           2       0.69      0.48      0.57      6400

    accuracy                           0.63     12800
   macro avg       0.65      0.63      0.63     12800
weighted avg       0.65      0.63      0.63     12800



  3%|▎         | 2999/100000 [19:36<11:03:45,  2.44it/s]

Epoch [3000/100000]


  3%|▎         | 3000/100000 [19:37<19:48:33,  1.36it/s]

Accuracy:0.6488
Precision:0.6350
Recall:0.6998
F1-Score:0.6658
ROC-AUC:0.6910

Confusion Matrix:
 [[4479 1921]
 [2575 3825]]

Classification Report:
               precision    recall  f1-score   support

           1       0.63      0.70      0.67      6400
           2       0.67      0.60      0.63      6400

    accuracy                           0.65     12800
   macro avg       0.65      0.65      0.65     12800
weighted avg       0.65      0.65      0.65     12800



  4%|▍         | 3999/100000 [26:04<10:02:21,  2.66it/s]

Epoch [4000/100000]


  4%|▍         | 4000/100000 [26:05<18:35:56,  1.43it/s]

Accuracy:0.6427
Precision:0.6268
Recall:0.7058
F1-Score:0.6639
ROC-AUC:0.6893

Confusion Matrix:
 [[4517 1883]
 [2690 3710]]

Classification Report:
               precision    recall  f1-score   support

           1       0.63      0.71      0.66      6400
           2       0.66      0.58      0.62      6400

    accuracy                           0.64     12800
   macro avg       0.65      0.64      0.64     12800
weighted avg       0.65      0.64      0.64     12800



  5%|▍         | 4999/100000 [32:41<10:03:10,  2.63it/s]

Epoch [5000/100000]


  5%|▌         | 5000/100000 [32:42<19:51:42,  1.33it/s]

Accuracy:0.6215
Precision:0.6129
Recall:0.6594
F1-Score:0.6353
ROC-AUC:0.6733

Confusion Matrix:
 [[4220 2180]
 [2665 3735]]

Classification Report:
               precision    recall  f1-score   support

           1       0.61      0.66      0.64      6400
           2       0.63      0.58      0.61      6400

    accuracy                           0.62     12800
   macro avg       0.62      0.62      0.62     12800
weighted avg       0.62      0.62      0.62     12800



  6%|▌         | 5999/100000 [39:05<9:55:17,  2.63it/s] 

Epoch [6000/100000]


  6%|▌         | 6000/100000 [39:06<19:05:05,  1.37it/s]

Accuracy:0.6272
Precision:0.6409
Recall:0.5786
F1-Score:0.6081
ROC-AUC:0.6892

Confusion Matrix:
 [[3703 2697]
 [2075 4325]]

Classification Report:
               precision    recall  f1-score   support

           1       0.64      0.58      0.61      6400
           2       0.62      0.68      0.64      6400

    accuracy                           0.63     12800
   macro avg       0.63      0.63      0.63     12800
weighted avg       0.63      0.63      0.63     12800



  7%|▋         | 6999/100000 [45:29<9:43:36,  2.66it/s] 

Epoch [7000/100000]


  7%|▋         | 7000/100000 [45:31<18:16:31,  1.41it/s]

Accuracy:0.6213
Precision:0.5924
Recall:0.7777
F1-Score:0.6725
ROC-AUC:0.6669

Confusion Matrix:
 [[4977 1423]
 [3424 2976]]

Classification Report:
               precision    recall  f1-score   support

           1       0.59      0.78      0.67      6400
           2       0.68      0.47      0.55      6400

    accuracy                           0.62     12800
   macro avg       0.63      0.62      0.61     12800
weighted avg       0.63      0.62      0.61     12800



  8%|▊         | 7999/100000 [51:54<9:32:33,  2.68it/s] 

Epoch [8000/100000]


  8%|▊         | 8000/100000 [51:55<17:48:30,  1.44it/s]

Accuracy:0.6149
Precision:0.6465
Recall:0.5070
F1-Score:0.5684
ROC-AUC:0.7010

Confusion Matrix:
 [[3245 3155]
 [1774 4626]]

Classification Report:
               precision    recall  f1-score   support

           1       0.65      0.51      0.57      6400
           2       0.59      0.72      0.65      6400

    accuracy                           0.61     12800
   macro avg       0.62      0.61      0.61     12800
weighted avg       0.62      0.61      0.61     12800



  9%|▉         | 8999/100000 [58:19<9:31:43,  2.65it/s] 

Epoch [9000/100000]


  9%|▉         | 9000/100000 [58:20<18:31:43,  1.36it/s]

Accuracy:0.6529
Precision:0.6406
Recall:0.6964
F1-Score:0.6674
ROC-AUC:0.7099

Confusion Matrix:
 [[4457 1943]
 [2500 3900]]

Classification Report:
               precision    recall  f1-score   support

           1       0.64      0.70      0.67      6400
           2       0.67      0.61      0.64      6400

    accuracy                           0.65     12800
   macro avg       0.65      0.65      0.65     12800
weighted avg       0.65      0.65      0.65     12800



 10%|▉         | 9999/100000 [1:04:50<11:05:21,  2.25it/s]

Epoch [10000/100000]


 10%|█         | 10000/100000 [1:04:52<20:00:31,  1.25it/s]

Accuracy:0.6041
Precision:0.5774
Recall:0.7770
F1-Score:0.6625
ROC-AUC:0.6602

Confusion Matrix:
 [[4973 1427]
 [3640 2760]]

Classification Report:
               precision    recall  f1-score   support

           1       0.58      0.78      0.66      6400
           2       0.66      0.43      0.52      6400

    accuracy                           0.60     12800
   macro avg       0.62      0.60      0.59     12800
weighted avg       0.62      0.60      0.59     12800



 11%|█         | 10999/100000 [1:11:16<9:17:45,  2.66it/s] 

Epoch [11000/100000]


 11%|█         | 11000/100000 [1:11:18<17:37:15,  1.40it/s]

Accuracy:0.6484
Precision:0.6306
Recall:0.7164
F1-Score:0.6708
ROC-AUC:0.6948

Confusion Matrix:
 [[4585 1815]
 [2686 3714]]

Classification Report:
               precision    recall  f1-score   support

           1       0.63      0.72      0.67      6400
           2       0.67      0.58      0.62      6400

    accuracy                           0.65     12800
   macro avg       0.65      0.65      0.65     12800
weighted avg       0.65      0.65      0.65     12800



 12%|█▏        | 11999/100000 [1:17:43<9:36:33,  2.54it/s] 

Epoch [12000/100000]


 12%|█▏        | 12000/100000 [1:17:45<18:50:36,  1.30it/s]

Accuracy:0.6413
Precision:0.6340
Recall:0.6686
F1-Score:0.6508
ROC-AUC:0.7089

Confusion Matrix:
 [[4279 2121]
 [2470 3930]]

Classification Report:
               precision    recall  f1-score   support

           1       0.63      0.67      0.65      6400
           2       0.65      0.61      0.63      6400

    accuracy                           0.64     12800
   macro avg       0.64      0.64      0.64     12800
weighted avg       0.64      0.64      0.64     12800



 13%|█▎        | 12999/100000 [1:24:16<9:25:01,  2.57it/s] 

Epoch [13000/100000]


 13%|█▎        | 13000/100000 [1:24:18<18:13:27,  1.33it/s]

Accuracy:0.6392
Precision:0.6257
Recall:0.6930
F1-Score:0.6576
ROC-AUC:0.6828

Confusion Matrix:
 [[4435 1965]
 [2653 3747]]

Classification Report:
               precision    recall  f1-score   support

           1       0.63      0.69      0.66      6400
           2       0.66      0.59      0.62      6400

    accuracy                           0.64     12800
   macro avg       0.64      0.64      0.64     12800
weighted avg       0.64      0.64      0.64     12800



 14%|█▍        | 13999/100000 [1:30:44<9:10:24,  2.60it/s] 

Epoch [14000/100000]


 14%|█▍        | 14000/100000 [1:30:45<17:25:15,  1.37it/s]

Accuracy:0.5860
Precision:0.5728
Recall:0.6770
F1-Score:0.6206
ROC-AUC:0.6477

Confusion Matrix:
 [[4333 2067]
 [3232 3168]]

Classification Report:
               precision    recall  f1-score   support

           1       0.57      0.68      0.62      6400
           2       0.61      0.49      0.54      6400

    accuracy                           0.59     12800
   macro avg       0.59      0.59      0.58     12800
weighted avg       0.59      0.59      0.58     12800



 15%|█▍        | 14999/100000 [1:37:06<8:45:01,  2.70it/s] 

Epoch [15000/100000]


 15%|█▌        | 15000/100000 [1:37:08<16:55:10,  1.40it/s]

Accuracy:0.6475
Precision:0.6010
Recall:0.8775
F1-Score:0.7134
ROC-AUC:0.7275

Confusion Matrix:
 [[5616  784]
 [3728 2672]]

Classification Report:
               precision    recall  f1-score   support

           1       0.60      0.88      0.71      6400
           2       0.77      0.42      0.54      6400

    accuracy                           0.65     12800
   macro avg       0.69      0.65      0.63     12800
weighted avg       0.69      0.65      0.63     12800



 16%|█▌        | 15999/100000 [1:43:28<8:45:31,  2.66it/s] 

Epoch [16000/100000]


 16%|█▌        | 16000/100000 [1:43:29<16:30:25,  1.41it/s]

Accuracy:0.6166
Precision:0.5757
Recall:0.8862
F1-Score:0.6980
ROC-AUC:0.7093

Confusion Matrix:
 [[5672  728]
 [4180 2220]]

Classification Report:
               precision    recall  f1-score   support

           1       0.58      0.89      0.70      6400
           2       0.75      0.35      0.47      6400

    accuracy                           0.62     12800
   macro avg       0.66      0.62      0.59     12800
weighted avg       0.66      0.62      0.59     12800



 17%|█▋        | 16999/100000 [1:49:56<8:34:46,  2.69it/s] 

Epoch [17000/100000]


 17%|█▋        | 17000/100000 [1:49:57<16:33:49,  1.39it/s]

Accuracy:0.6277
Precision:0.5880
Recall:0.8527
F1-Score:0.6960
ROC-AUC:0.6899

Confusion Matrix:
 [[5457  943]
 [3823 2577]]

Classification Report:
               precision    recall  f1-score   support

           1       0.59      0.85      0.70      6400
           2       0.73      0.40      0.52      6400

    accuracy                           0.63     12800
   macro avg       0.66      0.63      0.61     12800
weighted avg       0.66      0.63      0.61     12800



 18%|█▊        | 17999/100000 [1:56:18<8:32:26,  2.67it/s] 

Epoch [18000/100000]


 18%|█▊        | 18000/100000 [1:56:19<15:43:45,  1.45it/s]

Accuracy:0.5992
Precision:0.5711
Recall:0.7973
F1-Score:0.6655
ROC-AUC:0.6714

Confusion Matrix:
 [[5103 1297]
 [3833 2567]]

Classification Report:
               precision    recall  f1-score   support

           1       0.57      0.80      0.67      6400
           2       0.66      0.40      0.50      6400

    accuracy                           0.60     12800
   macro avg       0.62      0.60      0.58     12800
weighted avg       0.62      0.60      0.58     12800



 19%|█▉        | 18999/100000 [2:02:40<9:24:51,  2.39it/s] 

Epoch [19000/100000]


 19%|█▉        | 19000/100000 [2:02:42<17:11:52,  1.31it/s]

Accuracy:0.6111
Precision:0.5823
Recall:0.7864
F1-Score:0.6691
ROC-AUC:0.6606

Confusion Matrix:
 [[5033 1367]
 [3611 2789]]

Classification Report:
               precision    recall  f1-score   support

           1       0.58      0.79      0.67      6400
           2       0.67      0.44      0.53      6400

    accuracy                           0.61     12800
   macro avg       0.63      0.61      0.60     12800
weighted avg       0.63      0.61      0.60     12800



 20%|█▉        | 19999/100000 [2:09:18<8:38:08,  2.57it/s] 

Epoch [20000/100000]


 20%|██        | 20000/100000 [2:09:19<16:10:42,  1.37it/s]

Accuracy:0.6166
Precision:0.5857
Recall:0.7975
F1-Score:0.6754
ROC-AUC:0.6730

Confusion Matrix:
 [[5104 1296]
 [3611 2789]]

Classification Report:
               precision    recall  f1-score   support

           1       0.59      0.80      0.68      6400
           2       0.68      0.44      0.53      6400

    accuracy                           0.62     12800
   macro avg       0.63      0.62      0.60     12800
weighted avg       0.63      0.62      0.60     12800



 21%|██        | 20999/100000 [2:15:54<8:26:25,  2.60it/s] 

Epoch [21000/100000]


 21%|██        | 21000/100000 [2:15:56<16:20:19,  1.34it/s]

Accuracy:0.5591
Precision:0.5402
Recall:0.7944
F1-Score:0.6431
ROC-AUC:0.5818

Confusion Matrix:
 [[5084 1316]
 [4327 2073]]

Classification Report:
               precision    recall  f1-score   support

           1       0.54      0.79      0.64      6400
           2       0.61      0.32      0.42      6400

    accuracy                           0.56     12800
   macro avg       0.58      0.56      0.53     12800
weighted avg       0.58      0.56      0.53     12800



 22%|██▏       | 21999/100000 [2:22:26<8:14:03,  2.63it/s] 

Epoch [22000/100000]


 22%|██▏       | 22000/100000 [2:22:27<15:23:05,  1.41it/s]

Accuracy:0.6136
Precision:0.5792
Recall:0.8306
F1-Score:0.6825
ROC-AUC:0.6686

Confusion Matrix:
 [[5316 1084]
 [3862 2538]]

Classification Report:
               precision    recall  f1-score   support

           1       0.58      0.83      0.68      6400
           2       0.70      0.40      0.51      6400

    accuracy                           0.61     12800
   macro avg       0.64      0.61      0.59     12800
weighted avg       0.64      0.61      0.59     12800



 23%|██▎       | 22999/100000 [2:28:55<8:15:25,  2.59it/s] 

Epoch [23000/100000]


 23%|██▎       | 23000/100000 [2:28:57<16:21:30,  1.31it/s]

Accuracy:0.5926
Precision:0.5531
Recall:0.9647
F1-Score:0.7031
ROC-AUC:0.6740

Confusion Matrix:
 [[6174  226]
 [4989 1411]]

Classification Report:
               precision    recall  f1-score   support

           1       0.55      0.96      0.70      6400
           2       0.86      0.22      0.35      6400

    accuracy                           0.59     12800
   macro avg       0.71      0.59      0.53     12800
weighted avg       0.71      0.59      0.53     12800



 24%|██▍       | 23999/100000 [2:35:32<8:21:09,  2.53it/s] 

Epoch [24000/100000]


 24%|██▍       | 24000/100000 [2:35:33<15:57:37,  1.32it/s]

Accuracy:0.6134
Precision:0.5784
Recall:0.8370
F1-Score:0.6841
ROC-AUC:0.6764

Confusion Matrix:
 [[5357 1043]
 [3905 2495]]

Classification Report:
               precision    recall  f1-score   support

           1       0.58      0.84      0.68      6400
           2       0.71      0.39      0.50      6400

    accuracy                           0.61     12800
   macro avg       0.64      0.61      0.59     12800
weighted avg       0.64      0.61      0.59     12800



 25%|██▍       | 24999/100000 [2:42:07<7:58:26,  2.61it/s] 

Epoch [25000/100000]


 25%|██▌       | 25000/100000 [2:42:09<15:35:11,  1.34it/s]

Accuracy:0.6098
Precision:0.5723
Recall:0.8700
F1-Score:0.6904
ROC-AUC:0.7001

Confusion Matrix:
 [[5568  832]
 [4162 2238]]

Classification Report:
               precision    recall  f1-score   support

           1       0.57      0.87      0.69      6400
           2       0.73      0.35      0.47      6400

    accuracy                           0.61     12800
   macro avg       0.65      0.61      0.58     12800
weighted avg       0.65      0.61      0.58     12800



 26%|██▌       | 25999/100000 [2:48:40<8:11:07,  2.51it/s] 

Epoch [26000/100000]


 26%|██▌       | 26000/100000 [2:48:42<16:42:24,  1.23it/s]

Accuracy:0.6116
Precision:0.5789
Recall:0.8194
F1-Score:0.6784
ROC-AUC:0.6490

Confusion Matrix:
 [[5244 1156]
 [3815 2585]]

Classification Report:
               precision    recall  f1-score   support

           1       0.58      0.82      0.68      6400
           2       0.69      0.40      0.51      6400

    accuracy                           0.61     12800
   macro avg       0.63      0.61      0.59     12800
weighted avg       0.63      0.61      0.59     12800



 27%|██▋       | 26999/100000 [2:55:06<7:27:39,  2.72it/s] 

Epoch [27000/100000]


 27%|██▋       | 27000/100000 [2:55:08<13:33:59,  1.49it/s]

Accuracy:0.6155
Precision:0.5834
Recall:0.8080
F1-Score:0.6776
ROC-AUC:0.6855

Confusion Matrix:
 [[5171 1229]
 [3692 2708]]

Classification Report:
               precision    recall  f1-score   support

           1       0.58      0.81      0.68      6400
           2       0.69      0.42      0.52      6400

    accuracy                           0.62     12800
   macro avg       0.64      0.62      0.60     12800
weighted avg       0.64      0.62      0.60     12800



 28%|██▊       | 27999/100000 [3:01:39<7:58:05,  2.51it/s] 

Epoch [28000/100000]


 28%|██▊       | 28000/100000 [3:01:40<14:33:17,  1.37it/s]

Accuracy:0.6343
Precision:0.5975
Recall:0.8228
F1-Score:0.6923
ROC-AUC:0.7152

Confusion Matrix:
 [[5266 1134]
 [3547 2853]]

Classification Report:
               precision    recall  f1-score   support

           1       0.60      0.82      0.69      6400
           2       0.72      0.45      0.55      6400

    accuracy                           0.63     12800
   macro avg       0.66      0.63      0.62     12800
weighted avg       0.66      0.63      0.62     12800



 29%|██▉       | 28999/100000 [3:08:07<7:40:29,  2.57it/s] 

Epoch [29000/100000]


 29%|██▉       | 29000/100000 [3:08:09<14:48:38,  1.33it/s]

Accuracy:0.5923
Precision:0.5541
Recall:0.9455
F1-Score:0.6987
ROC-AUC:0.6909

Confusion Matrix:
 [[6051  349]
 [4869 1531]]

Classification Report:
               precision    recall  f1-score   support

           1       0.55      0.95      0.70      6400
           2       0.81      0.24      0.37      6400

    accuracy                           0.59     12800
   macro avg       0.68      0.59      0.53     12800
weighted avg       0.68      0.59      0.53     12800



 30%|██▉       | 29999/100000 [3:14:43<8:01:54,  2.42it/s] 

Epoch [30000/100000]


 30%|███       | 30000/100000 [3:14:44<14:27:36,  1.34it/s]

Accuracy:0.6059
Precision:0.5732
Recall:0.8295
F1-Score:0.6779
ROC-AUC:0.6804

Confusion Matrix:
 [[5309 1091]
 [3953 2447]]

Classification Report:
               precision    recall  f1-score   support

           1       0.57      0.83      0.68      6400
           2       0.69      0.38      0.49      6400

    accuracy                           0.61     12800
   macro avg       0.63      0.61      0.59     12800
weighted avg       0.63      0.61      0.59     12800



 31%|███       | 30999/100000 [3:21:17<7:23:37,  2.59it/s] 

Epoch [31000/100000]


 31%|███       | 31000/100000 [3:21:18<13:52:31,  1.38it/s]

Accuracy:0.6205
Precision:0.5879
Recall:0.8058
F1-Score:0.6798
ROC-AUC:0.6795

Confusion Matrix:
 [[5157 1243]
 [3615 2785]]

Classification Report:
               precision    recall  f1-score   support

           1       0.59      0.81      0.68      6400
           2       0.69      0.44      0.53      6400

    accuracy                           0.62     12800
   macro avg       0.64      0.62      0.61     12800
weighted avg       0.64      0.62      0.61     12800



 32%|███▏      | 31999/100000 [3:27:45<7:55:41,  2.38it/s] 

Epoch [32000/100000]


 32%|███▏      | 32000/100000 [3:27:46<13:52:20,  1.36it/s]

Accuracy:0.5850
Precision:0.5522
Recall:0.8994
F1-Score:0.6843
ROC-AUC:0.6610

Confusion Matrix:
 [[5756  644]
 [4668 1732]]

Classification Report:
               precision    recall  f1-score   support

           1       0.55      0.90      0.68      6400
           2       0.73      0.27      0.39      6400

    accuracy                           0.58     12800
   macro avg       0.64      0.58      0.54     12800
weighted avg       0.64      0.58      0.54     12800



 33%|███▎      | 32999/100000 [3:34:08<7:09:15,  2.60it/s] 

Epoch [33000/100000]


 33%|███▎      | 33000/100000 [3:34:10<13:55:47,  1.34it/s]

Accuracy:0.5850
Precision:0.5469
Recall:0.9919
F1-Score:0.7050
ROC-AUC:0.6796

Confusion Matrix:
 [[6348   52]
 [5260 1140]]

Classification Report:
               precision    recall  f1-score   support

           1       0.55      0.99      0.71      6400
           2       0.96      0.18      0.30      6400

    accuracy                           0.58     12800
   macro avg       0.75      0.58      0.50     12800
weighted avg       0.75      0.58      0.50     12800



 34%|███▍      | 33999/100000 [3:40:42<7:04:12,  2.59it/s] 

Epoch [34000/100000]


 34%|███▍      | 34000/100000 [3:40:44<13:46:03,  1.33it/s]

Accuracy:0.6011
Precision:0.5602
Recall:0.9414
F1-Score:0.7024
ROC-AUC:0.6477

Confusion Matrix:
 [[6025  375]
 [4731 1669]]

Classification Report:
               precision    recall  f1-score   support

           1       0.56      0.94      0.70      6400
           2       0.82      0.26      0.40      6400

    accuracy                           0.60     12800
   macro avg       0.69      0.60      0.55     12800
weighted avg       0.69      0.60      0.55     12800



 35%|███▍      | 34999/100000 [3:47:09<6:47:45,  2.66it/s] 

Epoch [35000/100000]


 35%|███▌      | 35000/100000 [3:47:10<12:35:32,  1.43it/s]

Accuracy:0.5834
Precision:0.5492
Recall:0.9303
F1-Score:0.6907
ROC-AUC:0.6612

Confusion Matrix:
 [[5954  446]
 [4887 1513]]

Classification Report:
               precision    recall  f1-score   support

           1       0.55      0.93      0.69      6400
           2       0.77      0.24      0.36      6400

    accuracy                           0.58     12800
   macro avg       0.66      0.58      0.53     12800
weighted avg       0.66      0.58      0.53     12800



 36%|███▌      | 35999/100000 [3:53:30<6:58:09,  2.55it/s] 

Epoch [36000/100000]


 36%|███▌      | 36000/100000 [3:53:32<13:07:26,  1.35it/s]

Accuracy:0.6151
Precision:0.5812
Recall:0.8237
F1-Score:0.6815
ROC-AUC:0.6907

Confusion Matrix:
 [[5272 1128]
 [3799 2601]]

Classification Report:
               precision    recall  f1-score   support

           1       0.58      0.82      0.68      6400
           2       0.70      0.41      0.51      6400

    accuracy                           0.62     12800
   macro avg       0.64      0.62      0.60     12800
weighted avg       0.64      0.62      0.60     12800



 37%|███▋      | 36999/100000 [3:59:58<6:37:02,  2.64it/s] 

Epoch [37000/100000]


 37%|███▋      | 37000/100000 [4:00:00<12:54:53,  1.36it/s]

Accuracy:0.5593
Precision:0.5317
Recall:0.9953
F1-Score:0.6931
ROC-AUC:0.6347

Confusion Matrix:
 [[6370   30]
 [5611  789]]

Classification Report:
               precision    recall  f1-score   support

           1       0.53      1.00      0.69      6400
           2       0.96      0.12      0.22      6400

    accuracy                           0.56     12800
   macro avg       0.75      0.56      0.46     12800
weighted avg       0.75      0.56      0.46     12800



 38%|███▊      | 37999/100000 [4:06:17<6:18:38,  2.73it/s] 

Epoch [38000/100000]


 38%|███▊      | 38000/100000 [4:06:19<11:33:31,  1.49it/s]

Accuracy:0.5548
Precision:0.5407
Recall:0.7280
F1-Score:0.6205
ROC-AUC:0.5914

Confusion Matrix:
 [[4659 1741]
 [3958 2442]]

Classification Report:
               precision    recall  f1-score   support

           1       0.54      0.73      0.62      6400
           2       0.58      0.38      0.46      6400

    accuracy                           0.55     12800
   macro avg       0.56      0.55      0.54     12800
weighted avg       0.56      0.55      0.54     12800



 39%|███▉      | 38999/100000 [4:12:57<6:22:23,  2.66it/s] 

Epoch [39000/100000]


 39%|███▉      | 39000/100000 [4:12:58<11:46:21,  1.44it/s]

Accuracy:0.6044
Precision:0.5668
Recall:0.8859
F1-Score:0.6913
ROC-AUC:0.6790

Confusion Matrix:
 [[5670  730]
 [4334 2066]]

Classification Report:
               precision    recall  f1-score   support

           1       0.57      0.89      0.69      6400
           2       0.74      0.32      0.45      6400

    accuracy                           0.60     12800
   macro avg       0.65      0.60      0.57     12800
weighted avg       0.65      0.60      0.57     12800



 40%|███▉      | 39999/100000 [4:19:17<6:09:54,  2.70it/s] 

Epoch [40000/100000]


 40%|████      | 40000/100000 [4:19:18<11:31:40,  1.45it/s]

Accuracy:0.5964
Precision:0.5611
Recall:0.8856
F1-Score:0.6869
ROC-AUC:0.6984

Confusion Matrix:
 [[5668  732]
 [4434 1966]]

Classification Report:
               precision    recall  f1-score   support

           1       0.56      0.89      0.69      6400
           2       0.73      0.31      0.43      6400

    accuracy                           0.60     12800
   macro avg       0.64      0.60      0.56     12800
weighted avg       0.64      0.60      0.56     12800



 41%|████      | 40999/100000 [4:25:36<6:06:58,  2.68it/s] 

Epoch [41000/100000]


 41%|████      | 41000/100000 [4:25:38<11:38:45,  1.41it/s]

Accuracy:0.5557
Precision:0.5296
Recall:0.9969
F1-Score:0.6917
ROC-AUC:0.5728

Confusion Matrix:
 [[6380   20]
 [5667  733]]

Classification Report:
               precision    recall  f1-score   support

           1       0.53      1.00      0.69      6400
           2       0.97      0.11      0.20      6400

    accuracy                           0.56     12800
   macro avg       0.75      0.56      0.45     12800
weighted avg       0.75      0.56      0.45     12800



 42%|████▏     | 41999/100000 [4:32:11<6:11:44,  2.60it/s] 

Epoch [42000/100000]


 42%|████▏     | 42000/100000 [4:32:13<11:56:49,  1.35it/s]

Accuracy:0.5929
Precision:0.5549
Recall:0.9395
F1-Score:0.6977
ROC-AUC:0.6558

Confusion Matrix:
 [[6013  387]
 [4824 1576]]

Classification Report:
               precision    recall  f1-score   support

           1       0.55      0.94      0.70      6400
           2       0.80      0.25      0.38      6400

    accuracy                           0.59     12800
   macro avg       0.68      0.59      0.54     12800
weighted avg       0.68      0.59      0.54     12800



 43%|████▎     | 42999/100000 [4:38:39<5:58:17,  2.65it/s] 

Epoch [43000/100000]


 43%|████▎     | 43000/100000 [4:38:41<10:50:01,  1.46it/s]

Accuracy:0.5689
Precision:0.5389
Recall:0.9545
F1-Score:0.6889
ROC-AUC:0.6827

Confusion Matrix:
 [[6109  291]
 [5227 1173]]

Classification Report:
               precision    recall  f1-score   support

           1       0.54      0.95      0.69      6400
           2       0.80      0.18      0.30      6400

    accuracy                           0.57     12800
   macro avg       0.67      0.57      0.49     12800
weighted avg       0.67      0.57      0.49     12800



 44%|████▍     | 43999/100000 [4:45:01<5:46:56,  2.69it/s] 

Epoch [44000/100000]


 44%|████▍     | 44000/100000 [4:45:03<10:43:13,  1.45it/s]

Accuracy:0.5665
Precision:0.5357
Recall:0.9988
F1-Score:0.6973
ROC-AUC:0.6771

Confusion Matrix:
 [[6392    8]
 [5541  859]]

Classification Report:
               precision    recall  f1-score   support

           1       0.54      1.00      0.70      6400
           2       0.99      0.13      0.24      6400

    accuracy                           0.57     12800
   macro avg       0.76      0.57      0.47     12800
weighted avg       0.76      0.57      0.47     12800



 45%|████▍     | 44999/100000 [4:51:20<5:59:25,  2.55it/s] 

Epoch [45000/100000]


 45%|████▌     | 45000/100000 [4:51:22<12:16:06,  1.25it/s]

Accuracy:0.5719
Precision:0.5440
Recall:0.8895
F1-Score:0.6751
ROC-AUC:0.6465

Confusion Matrix:
 [[5693  707]
 [4773 1627]]

Classification Report:
               precision    recall  f1-score   support

           1       0.54      0.89      0.68      6400
           2       0.70      0.25      0.37      6400

    accuracy                           0.57     12800
   macro avg       0.62      0.57      0.52     12800
weighted avg       0.62      0.57      0.52     12800



 46%|████▌     | 45999/100000 [4:57:46<5:28:47,  2.74it/s] 

Epoch [46000/100000]


 46%|████▌     | 46000/100000 [4:57:47<11:04:29,  1.35it/s]

Accuracy:0.5448
Precision:0.5236
Recall:0.9952
F1-Score:0.6861
ROC-AUC:0.6100

Confusion Matrix:
 [[6369   31]
 [5796  604]]

Classification Report:
               precision    recall  f1-score   support

           1       0.52      1.00      0.69      6400
           2       0.95      0.09      0.17      6400

    accuracy                           0.54     12800
   macro avg       0.74      0.54      0.43     12800
weighted avg       0.74      0.54      0.43     12800



 47%|████▋     | 46999/100000 [5:04:06<5:29:51,  2.68it/s] 

Epoch [47000/100000]


 47%|████▋     | 47000/100000 [5:04:07<10:52:55,  1.35it/s]

Accuracy:0.5725
Precision:0.5394
Recall:0.9931
F1-Score:0.6991
ROC-AUC:0.6902

Confusion Matrix:
 [[6356   44]
 [5428  972]]

Classification Report:
               precision    recall  f1-score   support

           1       0.54      0.99      0.70      6400
           2       0.96      0.15      0.26      6400

    accuracy                           0.57     12800
   macro avg       0.75      0.57      0.48     12800
weighted avg       0.75      0.57      0.48     12800



 48%|████▊     | 47999/100000 [5:10:26<5:17:09,  2.73it/s] 

Epoch [48000/100000]


 48%|████▊     | 48000/100000 [5:10:28<10:33:39,  1.37it/s]

Accuracy:0.5608
Precision:0.5327
Recall:0.9897
F1-Score:0.6926
ROC-AUC:0.6789

Confusion Matrix:
 [[6334   66]
 [5556  844]]

Classification Report:
               precision    recall  f1-score   support

           1       0.53      0.99      0.69      6400
           2       0.93      0.13      0.23      6400

    accuracy                           0.56     12800
   macro avg       0.73      0.56      0.46     12800
weighted avg       0.73      0.56      0.46     12800



 49%|████▉     | 48999/100000 [5:16:46<5:22:17,  2.64it/s] 

Epoch [49000/100000]


 49%|████▉     | 49000/100000 [5:16:48<10:15:59,  1.38it/s]

Accuracy:0.5824
Precision:0.5505
Recall:0.8981
F1-Score:0.6826
ROC-AUC:0.6518

Confusion Matrix:
 [[5748  652]
 [4693 1707]]

Classification Report:
               precision    recall  f1-score   support

           1       0.55      0.90      0.68      6400
           2       0.72      0.27      0.39      6400

    accuracy                           0.58     12800
   macro avg       0.64      0.58      0.54     12800
weighted avg       0.64      0.58      0.54     12800



 50%|████▉     | 49999/100000 [5:23:04<5:18:14,  2.62it/s] 

Epoch [50000/100000]


 50%|█████     | 50000/100000 [5:23:06<10:44:05,  1.29it/s]

Accuracy:0.6289
Precision:0.5856
Recall:0.8822
F1-Score:0.7039
ROC-AUC:0.6784

Confusion Matrix:
 [[5646  754]
 [3996 2404]]

Classification Report:
               precision    recall  f1-score   support

           1       0.59      0.88      0.70      6400
           2       0.76      0.38      0.50      6400

    accuracy                           0.63     12800
   macro avg       0.67      0.63      0.60     12800
weighted avg       0.67      0.63      0.60     12800



 51%|█████     | 50999/100000 [5:29:32<5:11:09,  2.62it/s] 

Epoch [51000/100000]


 51%|█████     | 51000/100000 [5:29:34<10:05:14,  1.35it/s]

Accuracy:0.5434
Precision:0.5229
Recall:0.9920
F1-Score:0.6848
ROC-AUC:0.6472

Confusion Matrix:
 [[6349   51]
 [5794  606]]

Classification Report:
               precision    recall  f1-score   support

           1       0.52      0.99      0.68      6400
           2       0.92      0.09      0.17      6400

    accuracy                           0.54     12800
   macro avg       0.72      0.54      0.43     12800
weighted avg       0.72      0.54      0.43     12800



 52%|█████▏    | 51999/100000 [5:35:50<5:05:20,  2.62it/s] 

Epoch [52000/100000]


 52%|█████▏    | 52000/100000 [5:35:51<9:54:28,  1.35it/s]

Accuracy:0.5748
Precision:0.5414
Recall:0.9788
F1-Score:0.6972
ROC-AUC:0.6749

Confusion Matrix:
 [[6264  136]
 [5306 1094]]

Classification Report:
               precision    recall  f1-score   support

           1       0.54      0.98      0.70      6400
           2       0.89      0.17      0.29      6400

    accuracy                           0.57     12800
   macro avg       0.72      0.57      0.49     12800
weighted avg       0.72      0.57      0.49     12800



 53%|█████▎    | 52999/100000 [5:42:16<5:19:04,  2.46it/s]

Epoch [53000/100000]


 53%|█████▎    | 53000/100000 [5:42:18<10:27:06,  1.25it/s]

Accuracy:0.5881
Precision:0.5534
Recall:0.9125
F1-Score:0.6890
ROC-AUC:0.6680

Confusion Matrix:
 [[5840  560]
 [4712 1688]]

Classification Report:
               precision    recall  f1-score   support

           1       0.55      0.91      0.69      6400
           2       0.75      0.26      0.39      6400

    accuracy                           0.59     12800
   macro avg       0.65      0.59      0.54     12800
weighted avg       0.65      0.59      0.54     12800



 54%|█████▍    | 53999/100000 [5:48:40<4:43:10,  2.71it/s] 

Epoch [54000/100000]


 54%|█████▍    | 54000/100000 [5:48:41<9:27:27,  1.35it/s]

Accuracy:0.5387
Precision:0.5202
Recall:0.9992
F1-Score:0.6842
ROC-AUC:0.5983

Confusion Matrix:
 [[6395    5]
 [5899  501]]

Classification Report:
               precision    recall  f1-score   support

           1       0.52      1.00      0.68      6400
           2       0.99      0.08      0.15      6400

    accuracy                           0.54     12800
   macro avg       0.76      0.54      0.41     12800
weighted avg       0.76      0.54      0.41     12800



 55%|█████▍    | 54999/100000 [5:54:58<5:08:05,  2.43it/s]

Epoch [55000/100000]


 55%|█████▌    | 55000/100000 [5:55:00<10:07:46,  1.23it/s]

Accuracy:0.5493
Precision:0.5260
Recall:0.9973
F1-Score:0.6888
ROC-AUC:0.6875

Confusion Matrix:
 [[6383   17]
 [5752  648]]

Classification Report:
               precision    recall  f1-score   support

           1       0.53      1.00      0.69      6400
           2       0.97      0.10      0.18      6400

    accuracy                           0.55     12800
   macro avg       0.75      0.55      0.44     12800
weighted avg       0.75      0.55      0.44     12800



 56%|█████▌    | 55999/100000 [6:01:22<4:39:51,  2.62it/s] 

Epoch [56000/100000]


 56%|█████▌    | 56000/100000 [6:01:23<9:35:03,  1.28it/s]

Accuracy:0.6038
Precision:0.5667
Recall:0.8819
F1-Score:0.6900
ROC-AUC:0.6948

Confusion Matrix:
 [[5644  756]
 [4316 2084]]

Classification Report:
               precision    recall  f1-score   support

           1       0.57      0.88      0.69      6400
           2       0.73      0.33      0.45      6400

    accuracy                           0.60     12800
   macro avg       0.65      0.60      0.57     12800
weighted avg       0.65      0.60      0.57     12800



 57%|█████▋    | 56999/100000 [6:07:43<4:50:57,  2.46it/s]

Epoch [57000/100000]


 57%|█████▋    | 57000/100000 [6:07:45<8:52:51,  1.34it/s]

Accuracy:0.5474
Precision:0.5274
Recall:0.9113
F1-Score:0.6682
ROC-AUC:0.6562

Confusion Matrix:
 [[5832  568]
 [5225 1175]]

Classification Report:
               precision    recall  f1-score   support

           1       0.53      0.91      0.67      6400
           2       0.67      0.18      0.29      6400

    accuracy                           0.55     12800
   macro avg       0.60      0.55      0.48     12800
weighted avg       0.60      0.55      0.48     12800



 58%|█████▊    | 57999/100000 [6:13:54<4:13:13,  2.76it/s]

Epoch [58000/100000]


 58%|█████▊    | 58000/100000 [6:13:55<8:21:38,  1.40it/s]

Accuracy:0.5367
Precision:0.5191
Recall:1.0000
F1-Score:0.6834
ROC-AUC:0.5775

Confusion Matrix:
 [[6400    0]
 [5930  470]]

Classification Report:
               precision    recall  f1-score   support

           1       0.52      1.00      0.68      6400
           2       1.00      0.07      0.14      6400

    accuracy                           0.54     12800
   macro avg       0.76      0.54      0.41     12800
weighted avg       0.76      0.54      0.41     12800



 59%|█████▉    | 58999/100000 [6:20:15<4:07:44,  2.76it/s]

Epoch [59000/100000]


 59%|█████▉    | 59000/100000 [6:20:17<8:40:21,  1.31it/s]

Accuracy:0.5613
Precision:0.5327
Recall:0.9998
F1-Score:0.6951
ROC-AUC:0.6690

Confusion Matrix:
 [[6399    1]
 [5614  786]]

Classification Report:
               precision    recall  f1-score   support

           1       0.53      1.00      0.70      6400
           2       1.00      0.12      0.22      6400

    accuracy                           0.56     12800
   macro avg       0.77      0.56      0.46     12800
weighted avg       0.77      0.56      0.46     12800



 60%|█████▉    | 59999/100000 [6:26:31<4:02:37,  2.75it/s]

Epoch [60000/100000]


 60%|██████    | 60000/100000 [6:26:32<7:14:01,  1.54it/s]

Accuracy:0.5352
Precision:0.5183
Recall:0.9994
F1-Score:0.6826
ROC-AUC:0.6993

Confusion Matrix:
 [[6396    4]
 [5945  455]]

Classification Report:
               precision    recall  f1-score   support

           1       0.52      1.00      0.68      6400
           2       0.99      0.07      0.13      6400

    accuracy                           0.54     12800
   macro avg       0.75      0.54      0.41     12800
weighted avg       0.75      0.54      0.41     12800



 61%|██████    | 60999/100000 [6:32:41<3:57:05,  2.74it/s]

Epoch [61000/100000]


 61%|██████    | 61000/100000 [6:32:43<7:34:15,  1.43it/s]

Accuracy:0.5798
Precision:0.5508
Recall:0.8655
F1-Score:0.6732
ROC-AUC:0.6408

Confusion Matrix:
 [[5539  861]
 [4517 1883]]

Classification Report:
               precision    recall  f1-score   support

           1       0.55      0.87      0.67      6400
           2       0.69      0.29      0.41      6400

    accuracy                           0.58     12800
   macro avg       0.62      0.58      0.54     12800
weighted avg       0.62      0.58      0.54     12800



 62%|██████▏   | 61999/100000 [6:38:57<3:54:12,  2.70it/s] 

Epoch [62000/100000]


 62%|██████▏   | 62000/100000 [6:38:58<8:09:27,  1.29it/s]

Accuracy:0.5436
Precision:0.5228
Recall:1.0000
F1-Score:0.6866
ROC-AUC:0.5532

Confusion Matrix:
 [[6400    0]
 [5842  558]]

Classification Report:
               precision    recall  f1-score   support

           1       0.52      1.00      0.69      6400
           2       1.00      0.09      0.16      6400

    accuracy                           0.54     12800
   macro avg       0.76      0.54      0.42     12800
weighted avg       0.76      0.54      0.42     12800



 63%|██████▎   | 62999/100000 [6:45:22<3:55:25,  2.62it/s]

Epoch [63000/100000]


 63%|██████▎   | 63000/100000 [6:45:23<7:31:14,  1.37it/s]

Accuracy:0.5352
Precision:0.5182
Recall:1.0000
F1-Score:0.6827
ROC-AUC:0.5312

Confusion Matrix:
 [[6400    0]
 [5950  450]]

Classification Report:
               precision    recall  f1-score   support

           1       0.52      1.00      0.68      6400
           2       1.00      0.07      0.13      6400

    accuracy                           0.54     12800
   macro avg       0.76      0.54      0.41     12800
weighted avg       0.76      0.54      0.41     12800



 64%|██████▍   | 63999/100000 [6:51:42<3:39:28,  2.73it/s]

Epoch [64000/100000]


 64%|██████▍   | 64000/100000 [6:51:43<6:40:58,  1.50it/s]

Accuracy:0.5501
Precision:0.5264
Recall:0.9998
F1-Score:0.6897
ROC-AUC:0.6237

Confusion Matrix:
 [[6399    1]
 [5758  642]]

Classification Report:
               precision    recall  f1-score   support

           1       0.53      1.00      0.69      6400
           2       1.00      0.10      0.18      6400

    accuracy                           0.55     12800
   macro avg       0.76      0.55      0.44     12800
weighted avg       0.76      0.55      0.44     12800



 65%|██████▍   | 64999/100000 [6:58:03<3:31:16,  2.76it/s]

Epoch [65000/100000]


 65%|██████▌   | 65000/100000 [6:58:05<7:03:25,  1.38it/s]

Accuracy:0.5200
Precision:0.5102
Recall:1.0000
F1-Score:0.6757
ROC-AUC:0.5563

Confusion Matrix:
 [[6400    0]
 [6144  256]]

Classification Report:
               precision    recall  f1-score   support

           1       0.51      1.00      0.68      6400
           2       1.00      0.04      0.08      6400

    accuracy                           0.52     12800
   macro avg       0.76      0.52      0.38     12800
weighted avg       0.76      0.52      0.38     12800



 66%|██████▌   | 65999/100000 [7:04:22<3:35:54,  2.62it/s]

Epoch [66000/100000]


 66%|██████▌   | 66000/100000 [7:04:23<6:52:49,  1.37it/s]

Accuracy:0.5413
Precision:0.5225
Recall:0.9569
F1-Score:0.6759
ROC-AUC:0.6472

Confusion Matrix:
 [[6124  276]
 [5596  804]]

Classification Report:
               precision    recall  f1-score   support

           1       0.52      0.96      0.68      6400
           2       0.74      0.13      0.21      6400

    accuracy                           0.54     12800
   macro avg       0.63      0.54      0.45     12800
weighted avg       0.63      0.54      0.45     12800



 67%|██████▋   | 66999/100000 [7:10:36<3:19:35,  2.76it/s]

Epoch [67000/100000]


 67%|██████▋   | 67000/100000 [7:10:38<6:43:09,  1.36it/s]

Accuracy:0.5483
Precision:0.5254
Recall:1.0000
F1-Score:0.6888
ROC-AUC:0.6081

Confusion Matrix:
 [[6400    0]
 [5782  618]]

Classification Report:
               precision    recall  f1-score   support

           1       0.53      1.00      0.69      6400
           2       1.00      0.10      0.18      6400

    accuracy                           0.55     12800
   macro avg       0.76      0.55      0.43     12800
weighted avg       0.76      0.55      0.43     12800



 68%|██████▊   | 67999/100000 [7:16:53<3:21:18,  2.65it/s]

Epoch [68000/100000]


 68%|██████▊   | 68000/100000 [7:16:55<6:04:34,  1.46it/s]

Accuracy:0.5528
Precision:0.5279
Recall:1.0000
F1-Score:0.6910
ROC-AUC:0.5533

Confusion Matrix:
 [[6400    0]
 [5724  676]]

Classification Report:
               precision    recall  f1-score   support

           1       0.53      1.00      0.69      6400
           2       1.00      0.11      0.19      6400

    accuracy                           0.55     12800
   macro avg       0.76      0.55      0.44     12800
weighted avg       0.76      0.55      0.44     12800



 69%|██████▉   | 68999/100000 [7:23:10<3:10:33,  2.71it/s]

Epoch [69000/100000]


 69%|██████▉   | 69000/100000 [7:23:11<6:04:38,  1.42it/s]

Accuracy:0.5406
Precision:0.5212
Recall:1.0000
F1-Score:0.6852
ROC-AUC:0.6536

Confusion Matrix:
 [[6400    0]
 [5880  520]]

Classification Report:
               precision    recall  f1-score   support

           1       0.52      1.00      0.69      6400
           2       1.00      0.08      0.15      6400

    accuracy                           0.54     12800
   macro avg       0.76      0.54      0.42     12800
weighted avg       0.76      0.54      0.42     12800



 69%|██████▉   | 69256/100000 [7:24:50<3:08:23,  2.72it/s]

In [None]:
model = ESM1b(tax_vocab_sizes).to(device)
optimizers = model.get_optimizers()


trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f'Trainable parameters: {trainable/ 1e6} M')
print(f'Total parameters: {total/ 1e6} M')
print(model)

# optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
criterion = nn.CrossEntropyLoss()
# Cosine annealing with warm restarts
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizers['0'],
    T_0=10,  # Initial restart interval
    T_mult=2,  # Multiply interval by 2 after each restart
    eta_min=1e-6  # Minimum learning rate
)

In [None]:
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score

val_dir = f"val_results/{model_name}"
if not os.path.exists(val_dir):
    os.makedirs(val_dir)
    
val_batches = [virus_da.get_batch() for _ in range(num_val // 2)] + [cellular_da.get_batch() for _ in range(num_val // 2)]

input_sequences = [e['Sequence'] for b in val_batches for e in b]
labels_ = [encode_lineage(e['Taxonomic_lineage__ALL_'])  for b in val_batches for e in b]

def evaluate(model):
    model.eval()  # Set model to evaluation mode
    
    df = {
        i : {
            "sequence": [],
            "label": [],
            "pred": [],
            "loss": []
        } for i in tax_vocab_sizes.keys()
    }

    metrics = {
        i : {
            "loss": 0,
            "accuracy": 0,
            "f1 macro": 0,
            "f1 micro": 0
        } for i in tax_vocab_sizes.keys()
    }
    
    # Process each sequence
    for sequence, label in zip(input_sequences, labels_):
        inputs = tokenizer_(
            [sequence],
            return_tensors="pt",
            padding='max_length',
            truncation=True,
            max_length=max_seq_len
        ).to(device)
    
        # Get model output
        with torch.no_grad():
            output = model(inputs['input_ids'], inputs['attention_mask'])

        for k in tax_vocab_sizes.keys():
            pred = output[str(k)].argmax(dim=-1).cpu().item()
            loss = criterion(output[str(k)], torch.tensor([label[k]]).to(device))
            df[k]["sequence"].append(sequence)
            df[k]["label"].append(level_decoder[k][label[k]])
            df[k]["pred"].append(level_decoder[k][pred])
            df[k]["loss"].append(round(loss.cpu().item(), 4))

    for k in tax_vocab_sizes.keys():
        # Convert to DataFrame
        new_df = pd.DataFrame(df[k])
        new_df['is_incorrect'] = new_df['label'] != new_df['pred']
        new_df = new_df.sort_values(['is_incorrect', 'loss'], ascending=[False, False])
        new_df.to_csv(f'val_results/{model_name}/classification_results_{k}.csv', index=False)

        metrics[k]["loss"] = np.array(df[k]["loss"]).mean()
        metrics[k]["accuracy"] = accuracy_score(np.array(df[k]["label"]), np.array(df[k]["pred"]))
        metrics[k]["f1 macro"] = f1_score(np.array(df[k]["label"]), np.array(df[k]["pred"]), average='macro')  # F1-score for multi-label classification
        metrics[k]["f1 micro"] = f1_score(np.array(df[k]["label"]), np.array(df[k]["pred"]), average='micro') 
    
    return metrics

In [None]:
evaluate(model)

In [None]:
running_loss = 0
current_lr = lr

for epoch in tqdm(range(epochs)):
    model.train()

    tensor_batch = mix_data_to_tensor_batch(virus_da.get_batch(), cellular_da.get_batch(), partition = get_partition_ratio(epoch+1))
    tensor_batch.gpu(device)
    
    labels = tensor_batch.taxes

    batch_loss = 0
    for index, optimizer in optimizers.items():
        optimizer.zero_grad()
        output = model(tensor_batch.seq_ids['input_ids'], tensor_batch.seq_ids['attention_mask'], index=index)
        loss = criterion(output, labels[int(index)])
        loss.backward()
        optimizer.step()
        batch_loss += loss
    
    running_loss += batch_loss.item()
    
    if (epoch + 1) % val_epoch == 0:
        train_loss = running_loss / val_epoch
        val_metrics = evaluate(model)
        val_losses = {k: v["loss"] for k, v in val_metrics.items()} 
        val_loss = sum([entry['loss'] for entry in val_metrics.values()]) 
        print(f"Epoch [{epoch + 1}/{epochs}]")
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Loss: {val_loss:.4f}")
        print(val_losses)
        
        # Create metrics dictionary for saving
        metrics = {
            "train_loss": train_loss,
            "val_loss": val_loss,
            "epoch": epoch + 1,
            "lr": current_lr,
            "partition": get_partition_ratio(epoch+1)
        }
        metrics.update(val_losses)
            
        # Save periodic checkpoint
        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pt')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizers,
            'scheduler_state_dict': scheduler.state_dict(),
            'metrics': metrics
        }, checkpoint_path)
        
        # Log to wandb
        wandb.log(metrics)

        # Step the scheduler
        scheduler.step(epoch + batch_loss.item())
        current_lr = scheduler.get_last_lr()[0]
        
        # Reset training metrics
        running_loss = 0
wandb.finish()