In [3]:
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.model_selection import KFold

In [4]:
import pandas as pd

In [5]:
# hic = pd.read_csv('data/Oligo_NN.loop_gene.csv')
#  # drop 'strand', 'gene_type'
# hic = hic.drop(['gene_id','strand','gene_type'], axis=1)
# # for Qanova	Eanova	Tanova, keep only 2 decimal places
# hic['Qanova'] = hic['Qanova'].apply(lambda x: round(x,1))
# hic['Eanova'] = hic['Eanova'].apply(lambda x: round(x,1))
# hic['Tanova'] = hic['Tanova'].apply(lambda x: round(x,1))
# hic
# #
# hic.head()
# hic.to_csv('data/Oligo_NN.loop_gene.csv')

In [6]:
from data_loader import load_data, get_balanced_data, normalize_features

In [12]:
data = load_data()
X_balanced, y_balanced = get_balanced_data(data)
print(len(X_balanced['mcg']), len(X_balanced['atac']))

Processed mcg data
Processed genebody data
Processed atac data
Processed hic data
zero: 9941, non-zero: 3183
4244 4244


In [9]:
# Mean Accuracy: 0.6460 ± 0.0097
HIDDEN_DIM = 64
NUM_LAYERS = 2
NUM_HEADS = 1
DROPOUT = 0.2
LR = 0.001
OUTPUT_DIM = 3  # number of classes (-1, 0, 1)
NUM_EPOCHS = 20
BATCH_SIZE = 32


# HIDDEN_DIM = 64
# NUM_LAYERS = 1
# NUM_HEADS = 1
# DROPOUT = 0.3
# LR = 1e-4
# OUTPUT_DIM = 3  # number of classes (-1, 0, 1)
# NUM_EPOCHS = 200
# BATCH_SIZE = 16


class TwoHeadTransformerModel(nn.Module):
    def __init__(self, mcg_input_dim, atac_input_dim, hidden_dim, output_dim, num_layers=2, num_heads=1, dropout=0.1):
        super(TwoHeadTransformerModel, self).__init__()
        self.mcg_embedding = nn.Linear(mcg_input_dim, hidden_dim)
        self.atac_embedding = nn.Linear(atac_input_dim, hidden_dim)
        
        # TODO: may need to use tanh in attention instead of softmax
        encoder_layers = nn.TransformerEncoderLayer(hidden_dim, num_heads, dim_feedforward=hidden_dim*2, dropout=dropout, batch_first=True, norm_first=True)
        self.mcg_transformer = nn.TransformerEncoder(encoder_layers, num_layers)
        self.atac_transformer = nn.TransformerEncoder(encoder_layers, num_layers)
        
        self.classifier = nn.Linear(hidden_dim * 2, output_dim)
        self.hidden_dim = hidden_dim

    def forward(self, mcg_x, mcg_mask, atac_x, atac_mask):
        mcg_x = self.mcg_embedding(mcg_x)
        atac_x = self.atac_embedding(atac_x)
        
        mcg_x = self.mcg_transformer(mcg_x, src_key_padding_mask=~mcg_mask.bool())
        atac_x = self.atac_transformer(atac_x, src_key_padding_mask=~atac_mask.bool())
        
        # Global average pooling
        mcg_x = mcg_x.mean(dim=1)
        atac_x = atac_x.mean(dim=1)
        
        # Concatenate MCG and ATAC embeddings
        combined_x = torch.cat((mcg_x, atac_x), dim=1)
        
        output = self.classifier(combined_x)
        return output

class CombinedGeneDataset(Dataset):
    def __init__(self, mcg_data, atac_data, labels):
        self.mcg_data = mcg_data
        self.atac_data = atac_data
        self.labels = labels
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        mcg_gene_data = torch.FloatTensor(self.mcg_data[idx])
        atac_gene_data = torch.FloatTensor(self.atac_data[idx])
        label = torch.LongTensor([self.labels[idx] + 1])  # Add 1 to shift labels to 0, 1, 2
        mcg_mask = torch.ones(len(mcg_gene_data))
        atac_mask = torch.ones(len(atac_gene_data))
        return mcg_gene_data, atac_gene_data, label, mcg_mask, atac_mask

def combined_collate_fn(batch):
    batch.sort(key=lambda x: len(x[0]), reverse=True)
    mcg_sequences, atac_sequences, labels, mcg_masks, atac_masks = zip(*batch)
    
    mcg_lengths = [len(seq) for seq in mcg_sequences]
    atac_lengths = [len(seq) for seq in atac_sequences]
    mcg_max_len = max(mcg_lengths)
    atac_max_len = max(atac_lengths)
    
    padded_mcg_seqs = torch.zeros(len(mcg_sequences), mcg_max_len, mcg_sequences[0].size(1))
    padded_atac_seqs = torch.zeros(len(atac_sequences), atac_max_len, atac_sequences[0].size(1))
    padded_mcg_masks = torch.zeros(len(mcg_sequences), mcg_max_len)
    padded_atac_masks = torch.zeros(len(atac_sequences), atac_max_len)
    
    for i, (mcg_seq, atac_seq, mcg_length, atac_length) in enumerate(zip(mcg_sequences, atac_sequences, mcg_lengths, atac_lengths)):
        padded_mcg_seqs[i, :mcg_length] = mcg_seq
        padded_atac_seqs[i, :atac_length] = atac_seq
        padded_mcg_masks[i, :mcg_length] = 1
        padded_atac_masks[i, :atac_length] = 1
    
    return padded_mcg_seqs, padded_atac_seqs, torch.cat(labels), padded_mcg_masks, padded_atac_masks

def train_combined_model(X_train_mcg, X_train_atac, y_train, X_test_mcg, X_test_atac, y_test, exp_name, fold_idx):
    #wandb.init(project='gene', group=exp_name, name=f'fold-{fold_idx}')
    mcg_input_dim = len(X_train_mcg[0][0])
    atac_input_dim = len(X_train_atac[0][0])
    
    train_dataset = CombinedGeneDataset(X_train_mcg, X_train_atac, y_train)
    test_dataset = CombinedGeneDataset(X_test_mcg, X_test_atac, y_test)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=combined_collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=combined_collate_fn)

    model = TwoHeadTransformerModel(mcg_input_dim, atac_input_dim, HIDDEN_DIM, OUTPUT_DIM, num_layers=NUM_LAYERS, num_heads=NUM_HEADS, dropout=DROPOUT)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LR)

    # lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=0.0001)
    # Create the OneCycleLR scheduler
    lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LR, total_steps=NUM_EPOCHS,
                              pct_start=0.8, anneal_strategy='cos',
                              cycle_momentum=False, div_factor=5.0,
                              final_div_factor=10.0)

    for epoch in tqdm(range(NUM_EPOCHS)):
        model.train()
        total_loss = 0
        train_correct = 0
        train_total = 0
        for mcg_x, atac_x, batch_y, mcg_mask, atac_mask in train_loader:
            optimizer.zero_grad()
            outputs = model(mcg_x, mcg_mask, atac_x, atac_mask)
            loss = criterion(outputs, batch_y.squeeze())
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            train_correct += (outputs.argmax(dim=1) == batch_y.squeeze()).sum().item()
            train_total += batch_y.size(0)
        lr_scheduler.step()

        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for mcg_x, atac_x, batch_y, mcg_mask, atac_mask in test_loader:
                outputs = model(mcg_x, mcg_mask, atac_x, atac_mask)
                _, predicted = torch.max(outputs.data, 1)
                total += batch_y.size(0)
                correct += (predicted == batch_y.squeeze()).sum().item()
        
        accuracy = correct / total
        print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], Train Loss: {total_loss/len(train_loader):.4f}, Train Accuracy: {train_correct/train_total:.4f}, Test Accuracy: {accuracy:.4f}')
        #wandb.log({'epoch': epoch, 'LR': optimizer.param_groups[0]['lr'], 'train_loss': total_loss/len(train_loader), 'train_accuracy': train_correct/train_total, 'test_accuracy': accuracy})

    model.eval()
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for mcg_x, atac_x, batch_y, mcg_mask, atac_mask in test_loader:
            outputs = model(mcg_x, mcg_mask, atac_x, atac_mask)
            _, predicted = torch.max(outputs.data, 1)
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(batch_y.cpu().numpy())

    final_accuracy = sum(np.array(all_predictions) == np.array(all_labels).squeeze()) / len(all_labels)
    print(f'Final Test Accuracy: {final_accuracy:.4f}')
    return final_accuracy

In [13]:
import time
exp_name = f'two-head-{time.strftime("%Y%m%d-%H%M%S")}'

# wandb.config.update({
#     'hidden_dim': HIDDEN_DIM,
#     'num_layers': NUM_LAYERS,
#     'num_heads': NUM_HEADS,
#     'dropout': DROPOUT,
#     'lr': LR,
#     'output_dim': OUTPUT_DIM,
#     'num_epochs': NUM_EPOCHS
# })


kf = KFold(n_splits=5, shuffle=True, random_state=25)
accuracies = []
for i, (train_index, test_index) in enumerate(kf.split(X_balanced['mcg'])):
    X_train_mcg, X_test_mcg = [X_balanced['mcg'][i] for i in train_index], [X_balanced['mcg'][i] for i in test_index]
    X_train_atac, X_test_atac = [X_balanced['atac'][i] for i in train_index], [X_balanced['atac'][i] for i in test_index]
    y_train, y_test = [y_balanced[i] for i in train_index], [y_balanced[i] for i in test_index]
    
    X_train_mcg_normalized, X_test_mcg_normalized = normalize_features(X_train_mcg, X_test_mcg)
    X_train_atac_normalized, X_test_atac_normalized = normalize_features(X_train_atac, X_test_atac)
    
    accuracies.append(train_combined_model(X_train_mcg_normalized, X_train_atac_normalized, y_train, 
                                           X_test_mcg_normalized, X_test_atac_normalized, y_test, exp_name=exp_name, fold_idx=i))

print(f'Mean Accuracy: {np.mean(accuracies):.4f} ± {np.std(accuracies):.4f}')

  5%|▌         | 1/20 [00:02<00:42,  2.22s/it]

Epoch [1/20], Train Loss: 0.9886, Train Accuracy: 0.5370, Test Accuracy: 0.6066


 10%|█         | 2/20 [00:04<00:37,  2.09s/it]

Epoch [2/20], Train Loss: 0.9039, Train Accuracy: 0.5844, Test Accuracy: 0.6431


 15%|█▌        | 3/20 [00:06<00:34,  2.04s/it]

Epoch [3/20], Train Loss: 0.7940, Train Accuracy: 0.6380, Test Accuracy: 0.6478


 20%|██        | 4/20 [00:08<00:32,  2.03s/it]

Epoch [4/20], Train Loss: 0.7994, Train Accuracy: 0.6415, Test Accuracy: 0.6690


 25%|██▌       | 5/20 [00:10<00:30,  2.02s/it]

Epoch [5/20], Train Loss: 0.7890, Train Accuracy: 0.6492, Test Accuracy: 0.6431


 30%|███       | 6/20 [00:12<00:28,  2.01s/it]

Epoch [6/20], Train Loss: 0.7765, Train Accuracy: 0.6465, Test Accuracy: 0.6243


 35%|███▌      | 7/20 [00:14<00:26,  2.02s/it]

Epoch [7/20], Train Loss: 0.7839, Train Accuracy: 0.6477, Test Accuracy: 0.6596


 40%|████      | 8/20 [00:16<00:24,  2.02s/it]

Epoch [8/20], Train Loss: 0.7716, Train Accuracy: 0.6524, Test Accuracy: 0.6820


 45%|████▌     | 9/20 [00:18<00:22,  2.01s/it]

Epoch [9/20], Train Loss: 0.7795, Train Accuracy: 0.6468, Test Accuracy: 0.6078


 50%|█████     | 10/20 [00:20<00:20,  2.06s/it]

Epoch [10/20], Train Loss: 0.7726, Train Accuracy: 0.6471, Test Accuracy: 0.6726


 55%|█████▌    | 11/20 [00:22<00:19,  2.12s/it]

Epoch [11/20], Train Loss: 0.7616, Train Accuracy: 0.6616, Test Accuracy: 0.6549


 60%|██████    | 12/20 [00:24<00:16,  2.09s/it]

Epoch [12/20], Train Loss: 0.7778, Train Accuracy: 0.6492, Test Accuracy: 0.6620


 65%|██████▌   | 13/20 [00:26<00:14,  2.08s/it]

Epoch [13/20], Train Loss: 0.7571, Train Accuracy: 0.6648, Test Accuracy: 0.6360


 70%|███████   | 14/20 [00:28<00:12,  2.06s/it]

Epoch [14/20], Train Loss: 0.7512, Train Accuracy: 0.6624, Test Accuracy: 0.6549


 75%|███████▌  | 15/20 [00:30<00:10,  2.06s/it]

Epoch [15/20], Train Loss: 0.7491, Train Accuracy: 0.6598, Test Accuracy: 0.6773


 80%|████████  | 16/20 [00:32<00:08,  2.05s/it]

Epoch [16/20], Train Loss: 0.7582, Train Accuracy: 0.6719, Test Accuracy: 0.6408


 85%|████████▌ | 17/20 [00:34<00:06,  2.05s/it]

Epoch [17/20], Train Loss: 0.7424, Train Accuracy: 0.6633, Test Accuracy: 0.6278


 90%|█████████ | 18/20 [00:36<00:04,  2.05s/it]

Epoch [18/20], Train Loss: 0.7288, Train Accuracy: 0.6748, Test Accuracy: 0.6832


 95%|█████████▌| 19/20 [00:38<00:02,  2.03s/it]

Epoch [19/20], Train Loss: 0.7123, Train Accuracy: 0.6816, Test Accuracy: 0.6843


100%|██████████| 20/20 [00:40<00:00,  2.05s/it]

Epoch [20/20], Train Loss: 0.6977, Train Accuracy: 0.6881, Test Accuracy: 0.6890
Final Test Accuracy: 0.6890



  5%|▌         | 1/20 [00:02<00:38,  2.03s/it]

Epoch [1/20], Train Loss: 0.9487, Train Accuracy: 0.5658, Test Accuracy: 0.6125


 10%|█         | 2/20 [00:04<00:36,  2.01s/it]

Epoch [2/20], Train Loss: 0.8031, Train Accuracy: 0.6409, Test Accuracy: 0.6525


 15%|█▌        | 3/20 [00:06<00:34,  2.01s/it]

Epoch [3/20], Train Loss: 0.7924, Train Accuracy: 0.6554, Test Accuracy: 0.6525


 20%|██        | 4/20 [00:08<00:32,  2.04s/it]

Epoch [4/20], Train Loss: 0.7779, Train Accuracy: 0.6486, Test Accuracy: 0.6478


 25%|██▌       | 5/20 [00:10<00:30,  2.02s/it]

Epoch [5/20], Train Loss: 0.7807, Train Accuracy: 0.6607, Test Accuracy: 0.6572


 30%|███       | 6/20 [00:12<00:28,  2.03s/it]

Epoch [6/20], Train Loss: 0.7805, Train Accuracy: 0.6524, Test Accuracy: 0.6608


 35%|███▌      | 7/20 [00:14<00:26,  2.03s/it]

Epoch [7/20], Train Loss: 0.7925, Train Accuracy: 0.6489, Test Accuracy: 0.6596


 40%|████      | 8/20 [00:16<00:24,  2.04s/it]

Epoch [8/20], Train Loss: 0.7857, Train Accuracy: 0.6504, Test Accuracy: 0.6584


 45%|████▌     | 9/20 [00:18<00:22,  2.05s/it]

Epoch [9/20], Train Loss: 0.7693, Train Accuracy: 0.6521, Test Accuracy: 0.6549


 50%|█████     | 10/20 [00:20<00:20,  2.03s/it]

Epoch [10/20], Train Loss: 0.7604, Train Accuracy: 0.6630, Test Accuracy: 0.6561


 55%|█████▌    | 11/20 [00:22<00:18,  2.04s/it]

Epoch [11/20], Train Loss: 0.7729, Train Accuracy: 0.6533, Test Accuracy: 0.6243


 60%|██████    | 12/20 [00:24<00:16,  2.04s/it]

Epoch [12/20], Train Loss: 0.7619, Train Accuracy: 0.6672, Test Accuracy: 0.6431


 65%|██████▌   | 13/20 [00:26<00:14,  2.04s/it]

Epoch [13/20], Train Loss: 0.7568, Train Accuracy: 0.6636, Test Accuracy: 0.6207


 70%|███████   | 14/20 [00:28<00:12,  2.05s/it]

Epoch [14/20], Train Loss: 0.7629, Train Accuracy: 0.6589, Test Accuracy: 0.5901


 75%|███████▌  | 15/20 [00:30<00:10,  2.04s/it]

Epoch [15/20], Train Loss: 0.7705, Train Accuracy: 0.6636, Test Accuracy: 0.6631


 80%|████████  | 16/20 [00:32<00:08,  2.05s/it]

Epoch [16/20], Train Loss: 0.7473, Train Accuracy: 0.6689, Test Accuracy: 0.6525


 85%|████████▌ | 17/20 [00:34<00:06,  2.04s/it]

Epoch [17/20], Train Loss: 0.7559, Train Accuracy: 0.6669, Test Accuracy: 0.6431


 90%|█████████ | 18/20 [00:37<00:04,  2.15s/it]

Epoch [18/20], Train Loss: 0.7307, Train Accuracy: 0.6772, Test Accuracy: 0.6372


 95%|█████████▌| 19/20 [00:39<00:02,  2.18s/it]

Epoch [19/20], Train Loss: 0.7127, Train Accuracy: 0.6878, Test Accuracy: 0.6655


100%|██████████| 20/20 [00:41<00:00,  2.07s/it]

Epoch [20/20], Train Loss: 0.7143, Train Accuracy: 0.6887, Test Accuracy: 0.6690
Final Test Accuracy: 0.6690



  5%|▌         | 1/20 [00:02<00:39,  2.07s/it]

Epoch [1/20], Train Loss: 0.9722, Train Accuracy: 0.5552, Test Accuracy: 0.5795


 10%|█         | 2/20 [00:04<00:36,  2.05s/it]

Epoch [2/20], Train Loss: 0.8669, Train Accuracy: 0.6065, Test Accuracy: 0.6137


 15%|█▌        | 3/20 [00:06<00:34,  2.05s/it]

Epoch [3/20], Train Loss: 0.8085, Train Accuracy: 0.6365, Test Accuracy: 0.6254


 20%|██        | 4/20 [00:08<00:32,  2.06s/it]

Epoch [4/20], Train Loss: 0.7863, Train Accuracy: 0.6374, Test Accuracy: 0.6502


 25%|██▌       | 5/20 [00:10<00:30,  2.06s/it]

Epoch [5/20], Train Loss: 0.7909, Train Accuracy: 0.6501, Test Accuracy: 0.6419


 30%|███       | 6/20 [00:12<00:29,  2.08s/it]

Epoch [6/20], Train Loss: 0.7729, Train Accuracy: 0.6471, Test Accuracy: 0.6631


 35%|███▌      | 7/20 [00:14<00:26,  2.06s/it]

Epoch [7/20], Train Loss: 0.7680, Train Accuracy: 0.6577, Test Accuracy: 0.6514


 40%|████      | 8/20 [00:16<00:24,  2.05s/it]

Epoch [8/20], Train Loss: 0.7770, Train Accuracy: 0.6539, Test Accuracy: 0.6502


 45%|████▌     | 9/20 [00:18<00:22,  2.05s/it]

Epoch [9/20], Train Loss: 0.7699, Train Accuracy: 0.6454, Test Accuracy: 0.6525


 50%|█████     | 10/20 [00:20<00:20,  2.05s/it]

Epoch [10/20], Train Loss: 0.7878, Train Accuracy: 0.6462, Test Accuracy: 0.6408


 55%|█████▌    | 11/20 [00:22<00:18,  2.04s/it]

Epoch [11/20], Train Loss: 0.7703, Train Accuracy: 0.6530, Test Accuracy: 0.6514


 60%|██████    | 12/20 [00:24<00:16,  2.04s/it]

Epoch [12/20], Train Loss: 0.7819, Train Accuracy: 0.6504, Test Accuracy: 0.6631


 65%|██████▌   | 13/20 [00:26<00:14,  2.04s/it]

Epoch [13/20], Train Loss: 0.7648, Train Accuracy: 0.6695, Test Accuracy: 0.6384


 70%|███████   | 14/20 [00:28<00:12,  2.03s/it]

Epoch [14/20], Train Loss: 0.7852, Train Accuracy: 0.6545, Test Accuracy: 0.6667


 75%|███████▌  | 15/20 [00:30<00:10,  2.04s/it]

Epoch [15/20], Train Loss: 0.7806, Train Accuracy: 0.6486, Test Accuracy: 0.6349


 80%|████████  | 16/20 [00:32<00:08,  2.05s/it]

Epoch [16/20], Train Loss: 0.7581, Train Accuracy: 0.6619, Test Accuracy: 0.6631


 85%|████████▌ | 17/20 [00:34<00:06,  2.05s/it]

Epoch [17/20], Train Loss: 0.7498, Train Accuracy: 0.6701, Test Accuracy: 0.6726


 90%|█████████ | 18/20 [00:36<00:04,  2.08s/it]

Epoch [18/20], Train Loss: 0.7407, Train Accuracy: 0.6707, Test Accuracy: 0.6667


 95%|█████████▌| 19/20 [00:39<00:02,  2.08s/it]

Epoch [19/20], Train Loss: 0.7183, Train Accuracy: 0.6898, Test Accuracy: 0.6702


100%|██████████| 20/20 [00:41<00:00,  2.06s/it]

Epoch [20/20], Train Loss: 0.7164, Train Accuracy: 0.6854, Test Accuracy: 0.6784
Final Test Accuracy: 0.6784



  5%|▌         | 1/20 [00:02<00:38,  2.02s/it]

Epoch [1/20], Train Loss: 0.9427, Train Accuracy: 0.5623, Test Accuracy: 0.5889


 10%|█         | 2/20 [00:04<00:36,  2.05s/it]

Epoch [2/20], Train Loss: 0.8198, Train Accuracy: 0.6265, Test Accuracy: 0.6502


 15%|█▌        | 3/20 [00:06<00:35,  2.07s/it]

Epoch [3/20], Train Loss: 0.7833, Train Accuracy: 0.6465, Test Accuracy: 0.6561


 20%|██        | 4/20 [00:08<00:33,  2.06s/it]

Epoch [4/20], Train Loss: 0.7774, Train Accuracy: 0.6577, Test Accuracy: 0.6384


 25%|██▌       | 5/20 [00:10<00:31,  2.07s/it]

Epoch [5/20], Train Loss: 0.7879, Train Accuracy: 0.6513, Test Accuracy: 0.6643


 30%|███       | 6/20 [00:12<00:29,  2.08s/it]

Epoch [6/20], Train Loss: 0.7716, Train Accuracy: 0.6583, Test Accuracy: 0.6584


 35%|███▌      | 7/20 [00:14<00:28,  2.15s/it]

Epoch [7/20], Train Loss: 0.7809, Train Accuracy: 0.6513, Test Accuracy: 0.6278


 40%|████      | 8/20 [00:16<00:26,  2.18s/it]

Epoch [8/20], Train Loss: 0.7631, Train Accuracy: 0.6589, Test Accuracy: 0.6631


 45%|████▌     | 9/20 [00:19<00:24,  2.19s/it]

Epoch [9/20], Train Loss: 0.7580, Train Accuracy: 0.6639, Test Accuracy: 0.6325


 50%|█████     | 10/20 [00:21<00:21,  2.18s/it]

Epoch [10/20], Train Loss: 0.7699, Train Accuracy: 0.6571, Test Accuracy: 0.6443


 55%|█████▌    | 11/20 [00:23<00:19,  2.16s/it]

Epoch [11/20], Train Loss: 0.7623, Train Accuracy: 0.6651, Test Accuracy: 0.6302


 60%|██████    | 12/20 [00:25<00:17,  2.14s/it]

Epoch [12/20], Train Loss: 0.7480, Train Accuracy: 0.6639, Test Accuracy: 0.6643


 65%|██████▌   | 13/20 [00:27<00:14,  2.12s/it]

Epoch [13/20], Train Loss: 0.7568, Train Accuracy: 0.6601, Test Accuracy: 0.6643


 70%|███████   | 14/20 [00:29<00:12,  2.11s/it]

Epoch [14/20], Train Loss: 0.7458, Train Accuracy: 0.6654, Test Accuracy: 0.6643


 75%|███████▌  | 15/20 [00:31<00:10,  2.11s/it]

Epoch [15/20], Train Loss: 0.7425, Train Accuracy: 0.6680, Test Accuracy: 0.6466


 80%|████████  | 16/20 [00:33<00:08,  2.12s/it]

Epoch [16/20], Train Loss: 0.7776, Train Accuracy: 0.6445, Test Accuracy: 0.6737


 85%|████████▌ | 17/20 [00:36<00:06,  2.12s/it]

Epoch [17/20], Train Loss: 0.7327, Train Accuracy: 0.6707, Test Accuracy: 0.6631


 90%|█████████ | 18/20 [00:38<00:04,  2.11s/it]

Epoch [18/20], Train Loss: 0.7145, Train Accuracy: 0.6842, Test Accuracy: 0.6655


 95%|█████████▌| 19/20 [00:40<00:02,  2.11s/it]

Epoch [19/20], Train Loss: 0.7003, Train Accuracy: 0.6943, Test Accuracy: 0.6631


100%|██████████| 20/20 [00:42<00:00,  2.12s/it]

Epoch [20/20], Train Loss: 0.7026, Train Accuracy: 0.6931, Test Accuracy: 0.6667
Final Test Accuracy: 0.6667



  5%|▌         | 1/20 [00:02<00:40,  2.12s/it]

Epoch [1/20], Train Loss: 0.9529, Train Accuracy: 0.5663, Test Accuracy: 0.5708


 10%|█         | 2/20 [00:04<00:37,  2.11s/it]

Epoch [2/20], Train Loss: 0.8328, Train Accuracy: 0.6272, Test Accuracy: 0.6333


 15%|█▌        | 3/20 [00:06<00:35,  2.11s/it]

Epoch [3/20], Train Loss: 0.7958, Train Accuracy: 0.6496, Test Accuracy: 0.6427


 20%|██        | 4/20 [00:08<00:33,  2.10s/it]

Epoch [4/20], Train Loss: 0.7855, Train Accuracy: 0.6572, Test Accuracy: 0.6509


 25%|██▌       | 5/20 [00:10<00:31,  2.10s/it]

Epoch [5/20], Train Loss: 0.7605, Train Accuracy: 0.6646, Test Accuracy: 0.6509


 30%|███       | 6/20 [00:12<00:29,  2.10s/it]

Epoch [6/20], Train Loss: 0.7757, Train Accuracy: 0.6558, Test Accuracy: 0.6392


 35%|███▌      | 7/20 [00:14<00:27,  2.11s/it]

Epoch [7/20], Train Loss: 0.7749, Train Accuracy: 0.6578, Test Accuracy: 0.6321


 40%|████      | 8/20 [00:16<00:25,  2.11s/it]

Epoch [8/20], Train Loss: 0.7826, Train Accuracy: 0.6578, Test Accuracy: 0.6450


 45%|████▌     | 9/20 [00:18<00:23,  2.11s/it]

Epoch [9/20], Train Loss: 0.7604, Train Accuracy: 0.6625, Test Accuracy: 0.6462


 50%|█████     | 10/20 [00:21<00:21,  2.12s/it]

Epoch [10/20], Train Loss: 0.7754, Train Accuracy: 0.6543, Test Accuracy: 0.6568


 55%|█████▌    | 11/20 [00:23<00:19,  2.12s/it]

Epoch [11/20], Train Loss: 0.7593, Train Accuracy: 0.6658, Test Accuracy: 0.5802


 60%|██████    | 12/20 [00:25<00:16,  2.12s/it]

Epoch [12/20], Train Loss: 0.7570, Train Accuracy: 0.6720, Test Accuracy: 0.6616


 65%|██████▌   | 13/20 [00:27<00:14,  2.13s/it]

Epoch [13/20], Train Loss: 0.7577, Train Accuracy: 0.6640, Test Accuracy: 0.6580


 70%|███████   | 14/20 [00:29<00:12,  2.14s/it]

Epoch [14/20], Train Loss: 0.7520, Train Accuracy: 0.6661, Test Accuracy: 0.6309


 75%|███████▌  | 15/20 [00:32<00:10,  2.20s/it]

Epoch [15/20], Train Loss: 0.7580, Train Accuracy: 0.6622, Test Accuracy: 0.6533


 80%|████████  | 16/20 [00:34<00:09,  2.29s/it]

Epoch [16/20], Train Loss: 0.7495, Train Accuracy: 0.6684, Test Accuracy: 0.6450


 85%|████████▌ | 17/20 [00:36<00:06,  2.27s/it]

Epoch [17/20], Train Loss: 0.7311, Train Accuracy: 0.6752, Test Accuracy: 0.6297


 90%|█████████ | 18/20 [00:38<00:04,  2.21s/it]

Epoch [18/20], Train Loss: 0.7421, Train Accuracy: 0.6776, Test Accuracy: 0.6533


 95%|█████████▌| 19/20 [00:40<00:02,  2.18s/it]

Epoch [19/20], Train Loss: 0.7151, Train Accuracy: 0.6820, Test Accuracy: 0.6639


100%|██████████| 20/20 [00:43<00:00,  2.15s/it]

Epoch [20/20], Train Loss: 0.7091, Train Accuracy: 0.6896, Test Accuracy: 0.6733
Final Test Accuracy: 0.6733
Mean Accuracy: 0.6753 ± 0.0080



