In [1]:
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.model_selection import KFold

In [2]:
import pandas as pd

In [3]:
# hic = pd.read_csv('data/Oligo_NN.loop_gene.csv')
#  # drop 'strand', 'gene_type'
# hic = hic.drop(['gene_id','strand','gene_type'], axis=1)
# # for Qanova	Eanova	Tanova, keep only 2 decimal places
# hic['Qanova'] = hic['Qanova'].apply(lambda x: round(x,1))
# hic['Eanova'] = hic['Eanova'].apply(lambda x: round(x,1))
# hic['Tanova'] = hic['Tanova'].apply(lambda x: round(x,1))
# hic
# #
# hic.head()
# hic.to_csv('data/Oligo_NN.loop_gene.csv')

In [4]:
from data_loader import load_data, get_balanced_data, normalize_features

In [5]:
data = load_data()
X_balanced, y_balanced = get_balanced_data(data)
print(len(X_balanced['mcg']), len(X_balanced['atac']))

Processed mcg data
Processed genebody data
Processed atac data
Processed hic data
zero: 9941, non-zero: 3183
4244 4244


In [6]:
# Mean Accuracy: 0.6460 ± 0.0097
HIDDEN_DIM = 64
NUM_LAYERS = 2
NUM_HEADS = 1
DROPOUT = 0.2
LR = 0.001
OUTPUT_DIM = 3  # number of classes (-1, 0, 1)
NUM_EPOCHS = 20
BATCH_SIZE = 32


# HIDDEN_DIM = 64
# NUM_LAYERS = 1
# NUM_HEADS = 1
# DROPOUT = 0.3
# LR = 1e-4
# OUTPUT_DIM = 3  # number of classes (-1, 0, 1)
# NUM_EPOCHS = 200
# BATCH_SIZE = 16


class TwoHeadTransformerModel(nn.Module):
    def __init__(self, mcg_input_dim, atac_input_dim, hidden_dim, output_dim, num_layers=2, num_heads=1, dropout=0.1):
        super(TwoHeadTransformerModel, self).__init__()
        self.mcg_embedding = nn.Linear(mcg_input_dim, hidden_dim)
        self.atac_embedding = nn.Linear(atac_input_dim, hidden_dim)
        
        # TODO: may need to use tanh in attention instead of softmax
        encoder_layers = nn.TransformerEncoderLayer(hidden_dim, num_heads, dim_feedforward=hidden_dim*2, dropout=dropout, batch_first=True, norm_first=True)
        self.mcg_transformer = nn.TransformerEncoder(encoder_layers, num_layers)
        self.atac_transformer = nn.TransformerEncoder(encoder_layers, num_layers)
        
        self.classifier = nn.Linear(hidden_dim * 2, output_dim)
        self.hidden_dim = hidden_dim

    def forward(self, mcg_x, mcg_mask, atac_x, atac_mask):
        mcg_x = self.mcg_embedding(mcg_x)
        atac_x = self.atac_embedding(atac_x)
        
        mcg_x = self.mcg_transformer(mcg_x, src_key_padding_mask=~mcg_mask.bool())
        atac_x = self.atac_transformer(atac_x, src_key_padding_mask=~atac_mask.bool())
        
        # Global average pooling
        mcg_x = mcg_x.mean(dim=1)
        atac_x = atac_x.mean(dim=1)
        
        # Concatenate MCG and ATAC embeddings
        combined_x = torch.cat((mcg_x, atac_x), dim=1)
        
        output = self.classifier(combined_x)
        return output

class CombinedGeneDataset(Dataset):
    def __init__(self, mcg_data, atac_data, labels):
        self.mcg_data = mcg_data
        self.atac_data = atac_data
        self.labels = labels
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        mcg_gene_data = torch.FloatTensor(self.mcg_data[idx])
        atac_gene_data = torch.FloatTensor(self.atac_data[idx])
        label = torch.LongTensor([self.labels[idx] + 1])  # Add 1 to shift labels to 0, 1, 2
        mcg_mask = torch.ones(len(mcg_gene_data))
        atac_mask = torch.ones(len(atac_gene_data))
        return mcg_gene_data, atac_gene_data, label, mcg_mask, atac_mask

def combined_collate_fn(batch):
    batch.sort(key=lambda x: len(x[0]), reverse=True)
    mcg_sequences, atac_sequences, labels, mcg_masks, atac_masks = zip(*batch)
    
    mcg_lengths = [len(seq) for seq in mcg_sequences]
    atac_lengths = [len(seq) for seq in atac_sequences]
    mcg_max_len = max(mcg_lengths)
    atac_max_len = max(atac_lengths)
    
    padded_mcg_seqs = torch.zeros(len(mcg_sequences), mcg_max_len, mcg_sequences[0].size(1))
    padded_atac_seqs = torch.zeros(len(atac_sequences), atac_max_len, atac_sequences[0].size(1))
    padded_mcg_masks = torch.zeros(len(mcg_sequences), mcg_max_len)
    padded_atac_masks = torch.zeros(len(atac_sequences), atac_max_len)
    
    for i, (mcg_seq, atac_seq, mcg_length, atac_length) in enumerate(zip(mcg_sequences, atac_sequences, mcg_lengths, atac_lengths)):
        padded_mcg_seqs[i, :mcg_length] = mcg_seq
        padded_atac_seqs[i, :atac_length] = atac_seq
        padded_mcg_masks[i, :mcg_length] = 1
        padded_atac_masks[i, :atac_length] = 1
    
    return padded_mcg_seqs, padded_atac_seqs, torch.cat(labels), padded_mcg_masks, padded_atac_masks

def train_combined_model(X_train_mcg, X_train_atac, y_train, X_test_mcg, X_test_atac, y_test, exp_name, fold_idx):
    #wandb.init(project='gene', group=exp_name, name=f'fold-{fold_idx}')
    mcg_input_dim = len(X_train_mcg[0][0])
    atac_input_dim = len(X_train_atac[0][0])
    
    train_dataset = CombinedGeneDataset(X_train_mcg, X_train_atac, y_train)
    test_dataset = CombinedGeneDataset(X_test_mcg, X_test_atac, y_test)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=combined_collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=combined_collate_fn)

    model = TwoHeadTransformerModel(mcg_input_dim, atac_input_dim, HIDDEN_DIM, OUTPUT_DIM, num_layers=NUM_LAYERS, num_heads=NUM_HEADS, dropout=DROPOUT)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LR)

    # lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=0.0001)
    # Create the OneCycleLR scheduler
    lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LR, total_steps=NUM_EPOCHS,
                              pct_start=0.8, anneal_strategy='cos',
                              cycle_momentum=False, div_factor=5.0,
                              final_div_factor=10.0)

    for epoch in tqdm(range(NUM_EPOCHS)):
        model.train()
        total_loss = 0
        train_correct = 0
        train_total = 0
        for mcg_x, atac_x, batch_y, mcg_mask, atac_mask in train_loader:
            optimizer.zero_grad()
            outputs = model(mcg_x, mcg_mask, atac_x, atac_mask)
            loss = criterion(outputs, batch_y.squeeze())
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            train_correct += (outputs.argmax(dim=1) == batch_y.squeeze()).sum().item()
            train_total += batch_y.size(0)
        lr_scheduler.step()

        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for mcg_x, atac_x, batch_y, mcg_mask, atac_mask in test_loader:
                outputs = model(mcg_x, mcg_mask, atac_x, atac_mask)
                _, predicted = torch.max(outputs.data, 1)
                total += batch_y.size(0)
                correct += (predicted == batch_y.squeeze()).sum().item()
        
        accuracy = correct / total
        print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], Train Loss: {total_loss/len(train_loader):.4f}, Train Accuracy: {train_correct/train_total:.4f}, Test Accuracy: {accuracy:.4f}')
        #wandb.log({'epoch': epoch, 'LR': optimizer.param_groups[0]['lr'], 'train_loss': total_loss/len(train_loader), 'train_accuracy': train_correct/train_total, 'test_accuracy': accuracy})

    model.eval()
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for mcg_x, atac_x, batch_y, mcg_mask, atac_mask in test_loader:
            outputs = model(mcg_x, mcg_mask, atac_x, atac_mask)
            _, predicted = torch.max(outputs.data, 1)
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(batch_y.cpu().numpy())

    final_accuracy = sum(np.array(all_predictions) == np.array(all_labels).squeeze()) / len(all_labels)
    print(f'Final Test Accuracy: {final_accuracy:.4f}')
    return final_accuracy

In [7]:
import time
exp_name = f'two-head-{time.strftime("%Y%m%d-%H%M%S")}'

# wandb.config.update({
#     'hidden_dim': HIDDEN_DIM,
#     'num_layers': NUM_LAYERS,
#     'num_heads': NUM_HEADS,
#     'dropout': DROPOUT,
#     'lr': LR,
#     'output_dim': OUTPUT_DIM,
#     'num_epochs': NUM_EPOCHS
# })


kf = KFold(n_splits=5, shuffle=True, random_state=25)
accuracies = []
for i, (train_index, test_index) in enumerate(kf.split(X_balanced['mcg'])):
    X_train_mcg, X_test_mcg = [X_balanced['mcg'][i] for i in train_index], [X_balanced['mcg'][i] for i in test_index]
    X_train_atac, X_test_atac = [X_balanced['atac'][i] for i in train_index], [X_balanced['atac'][i] for i in test_index]
    y_train, y_test = [y_balanced[i] for i in train_index], [y_balanced[i] for i in test_index]
    
    X_train_mcg_normalized, X_test_mcg_normalized = normalize_features(X_train_mcg, X_test_mcg)
    X_train_atac_normalized, X_test_atac_normalized = normalize_features(X_train_atac, X_test_atac)
    
    accuracies.append(train_combined_model(X_train_mcg_normalized, X_train_atac_normalized, y_train, 
                                           X_test_mcg_normalized, X_test_atac_normalized, y_test, exp_name=exp_name, fold_idx=i))

print(f'Mean Accuracy: {np.mean(accuracies):.4f} ± {np.std(accuracies):.4f}')

  5%|▌         | 1/20 [00:03<01:13,  3.87s/it]

Epoch [1/20], Train Loss: 0.9441, Train Accuracy: 0.5567, Test Accuracy: 0.6219


 10%|█         | 2/20 [00:07<01:08,  3.83s/it]

Epoch [2/20], Train Loss: 0.8203, Train Accuracy: 0.6253, Test Accuracy: 0.6596


 15%|█▌        | 3/20 [00:11<01:03,  3.76s/it]

Epoch [3/20], Train Loss: 0.7789, Train Accuracy: 0.6421, Test Accuracy: 0.6584


 20%|██        | 4/20 [00:14<00:59,  3.70s/it]

Epoch [4/20], Train Loss: 0.7782, Train Accuracy: 0.6515, Test Accuracy: 0.6514


 25%|██▌       | 5/20 [00:18<00:56,  3.74s/it]

Epoch [5/20], Train Loss: 0.7971, Train Accuracy: 0.6359, Test Accuracy: 0.6678


 30%|███       | 6/20 [00:22<00:50,  3.58s/it]

Epoch [6/20], Train Loss: 0.7612, Train Accuracy: 0.6536, Test Accuracy: 0.6655


 35%|███▌      | 7/20 [00:25<00:45,  3.50s/it]

Epoch [7/20], Train Loss: 0.7560, Train Accuracy: 0.6619, Test Accuracy: 0.6867


 40%|████      | 8/20 [00:28<00:41,  3.45s/it]

Epoch [8/20], Train Loss: 0.7757, Train Accuracy: 0.6501, Test Accuracy: 0.6784


 45%|████▌     | 9/20 [00:32<00:37,  3.42s/it]

Epoch [9/20], Train Loss: 0.7542, Train Accuracy: 0.6616, Test Accuracy: 0.6761


 50%|█████     | 10/20 [00:35<00:34,  3.42s/it]

Epoch [10/20], Train Loss: 0.7667, Train Accuracy: 0.6551, Test Accuracy: 0.6726


 55%|█████▌    | 11/20 [00:39<00:31,  3.48s/it]

Epoch [11/20], Train Loss: 0.7597, Train Accuracy: 0.6551, Test Accuracy: 0.6714


 60%|██████    | 12/20 [00:42<00:28,  3.54s/it]

Epoch [12/20], Train Loss: 0.7436, Train Accuracy: 0.6642, Test Accuracy: 0.6784


 65%|██████▌   | 13/20 [00:45<00:23,  3.39s/it]

Epoch [13/20], Train Loss: 0.7472, Train Accuracy: 0.6607, Test Accuracy: 0.6019


 70%|███████   | 14/20 [00:50<00:22,  3.82s/it]

Epoch [14/20], Train Loss: 0.7866, Train Accuracy: 0.6454, Test Accuracy: 0.6643


 75%|███████▌  | 15/20 [00:56<00:22,  4.48s/it]

Epoch [15/20], Train Loss: 0.7413, Train Accuracy: 0.6607, Test Accuracy: 0.6726


 80%|████████  | 16/20 [01:00<00:17,  4.42s/it]

Epoch [16/20], Train Loss: 0.7634, Train Accuracy: 0.6633, Test Accuracy: 0.6737


 85%|████████▌ | 17/20 [01:04<00:12,  4.29s/it]

Epoch [17/20], Train Loss: 0.7322, Train Accuracy: 0.6695, Test Accuracy: 0.6113


 90%|█████████ | 18/20 [01:08<00:08,  4.05s/it]

Epoch [18/20], Train Loss: 0.7398, Train Accuracy: 0.6736, Test Accuracy: 0.6855


 95%|█████████▌| 19/20 [01:12<00:04,  4.17s/it]

Epoch [19/20], Train Loss: 0.7081, Train Accuracy: 0.6884, Test Accuracy: 0.6890


100%|██████████| 20/20 [01:17<00:00,  3.87s/it]

Epoch [20/20], Train Loss: 0.6873, Train Accuracy: 0.6943, Test Accuracy: 0.6890





Final Test Accuracy: 0.6890


  5%|▌         | 1/20 [00:04<01:28,  4.67s/it]

Epoch [1/20], Train Loss: 0.9653, Train Accuracy: 0.5649, Test Accuracy: 0.5901


 10%|█         | 2/20 [00:08<01:16,  4.24s/it]

Epoch [2/20], Train Loss: 0.8204, Train Accuracy: 0.6327, Test Accuracy: 0.6384


 15%|█▌        | 3/20 [00:12<01:07,  4.00s/it]

Epoch [3/20], Train Loss: 0.7892, Train Accuracy: 0.6486, Test Accuracy: 0.6455


 20%|██        | 4/20 [00:16<01:03,  3.95s/it]

Epoch [4/20], Train Loss: 0.7652, Train Accuracy: 0.6636, Test Accuracy: 0.6702


 25%|██▌       | 5/20 [00:19<00:57,  3.81s/it]

Epoch [5/20], Train Loss: 0.7777, Train Accuracy: 0.6495, Test Accuracy: 0.6431


 30%|███       | 6/20 [00:23<00:52,  3.78s/it]

Epoch [6/20], Train Loss: 0.7477, Train Accuracy: 0.6651, Test Accuracy: 0.6514


 35%|███▌      | 7/20 [00:27<00:49,  3.84s/it]

Epoch [7/20], Train Loss: 0.7466, Train Accuracy: 0.6675, Test Accuracy: 0.6254


 40%|████      | 8/20 [00:31<00:45,  3.77s/it]

Epoch [8/20], Train Loss: 0.7767, Train Accuracy: 0.6483, Test Accuracy: 0.6749


 45%|████▌     | 9/20 [00:34<00:39,  3.60s/it]

Epoch [9/20], Train Loss: 0.7447, Train Accuracy: 0.6728, Test Accuracy: 0.6408


 50%|█████     | 10/20 [00:37<00:34,  3.49s/it]

Epoch [10/20], Train Loss: 0.7500, Train Accuracy: 0.6692, Test Accuracy: 0.6549


 55%|█████▌    | 11/20 [00:41<00:32,  3.58s/it]

Epoch [11/20], Train Loss: 0.7424, Train Accuracy: 0.6730, Test Accuracy: 0.6125


 60%|██████    | 12/20 [00:45<00:29,  3.69s/it]

Epoch [12/20], Train Loss: 0.7471, Train Accuracy: 0.6574, Test Accuracy: 0.6584


 65%|██████▌   | 13/20 [00:49<00:26,  3.83s/it]

Epoch [13/20], Train Loss: 0.7649, Train Accuracy: 0.6616, Test Accuracy: 0.6655


 70%|███████   | 14/20 [00:53<00:22,  3.80s/it]

Epoch [14/20], Train Loss: 0.7334, Train Accuracy: 0.6745, Test Accuracy: 0.6643


 75%|███████▌  | 15/20 [00:57<00:19,  3.84s/it]

Epoch [15/20], Train Loss: 0.7368, Train Accuracy: 0.6748, Test Accuracy: 0.6466


 80%|████████  | 16/20 [01:00<00:14,  3.69s/it]

Epoch [16/20], Train Loss: 0.7433, Train Accuracy: 0.6692, Test Accuracy: 0.6690


 85%|████████▌ | 17/20 [01:03<00:10,  3.59s/it]

Epoch [17/20], Train Loss: 0.7367, Train Accuracy: 0.6730, Test Accuracy: 0.6620


 90%|█████████ | 18/20 [01:07<00:07,  3.64s/it]

Epoch [18/20], Train Loss: 0.7096, Train Accuracy: 0.6804, Test Accuracy: 0.6584


 95%|█████████▌| 19/20 [01:11<00:03,  3.67s/it]

Epoch [19/20], Train Loss: 0.6935, Train Accuracy: 0.6954, Test Accuracy: 0.6667


100%|██████████| 20/20 [01:14<00:00,  3.75s/it]

Epoch [20/20], Train Loss: 0.6820, Train Accuracy: 0.6957, Test Accuracy: 0.6667





Final Test Accuracy: 0.6667


  5%|▌         | 1/20 [00:04<01:16,  4.04s/it]

Epoch [1/20], Train Loss: 0.9515, Train Accuracy: 0.5620, Test Accuracy: 0.6231


 10%|█         | 2/20 [00:07<01:09,  3.87s/it]

Epoch [2/20], Train Loss: 0.7964, Train Accuracy: 0.6383, Test Accuracy: 0.6561


 15%|█▌        | 3/20 [00:11<01:01,  3.63s/it]

Epoch [3/20], Train Loss: 0.8012, Train Accuracy: 0.6433, Test Accuracy: 0.6184


 20%|██        | 4/20 [00:14<00:56,  3.56s/it]

Epoch [4/20], Train Loss: 0.7904, Train Accuracy: 0.6412, Test Accuracy: 0.6455


 25%|██▌       | 5/20 [00:18<00:53,  3.56s/it]

Epoch [5/20], Train Loss: 0.7753, Train Accuracy: 0.6539, Test Accuracy: 0.6584


 30%|███       | 6/20 [00:23<00:56,  4.01s/it]

Epoch [6/20], Train Loss: 0.7744, Train Accuracy: 0.6542, Test Accuracy: 0.6207


 35%|███▌      | 7/20 [00:27<00:52,  4.07s/it]

Epoch [7/20], Train Loss: 0.7814, Train Accuracy: 0.6563, Test Accuracy: 0.6667


 40%|████      | 8/20 [00:31<00:49,  4.09s/it]

Epoch [8/20], Train Loss: 0.7815, Train Accuracy: 0.6495, Test Accuracy: 0.6525


 45%|████▌     | 9/20 [00:34<00:43,  3.91s/it]

Epoch [9/20], Train Loss: 0.7668, Train Accuracy: 0.6592, Test Accuracy: 0.6608


 50%|█████     | 10/20 [00:38<00:39,  3.95s/it]

Epoch [10/20], Train Loss: 0.7623, Train Accuracy: 0.6548, Test Accuracy: 0.6796


 55%|█████▌    | 11/20 [00:42<00:34,  3.85s/it]

Epoch [11/20], Train Loss: 0.7773, Train Accuracy: 0.6465, Test Accuracy: 0.6643


 60%|██████    | 12/20 [00:46<00:30,  3.87s/it]

Epoch [12/20], Train Loss: 0.7664, Train Accuracy: 0.6577, Test Accuracy: 0.6784


 65%|██████▌   | 13/20 [00:50<00:27,  3.98s/it]

Epoch [13/20], Train Loss: 0.7659, Train Accuracy: 0.6545, Test Accuracy: 0.6714


 70%|███████   | 14/20 [00:54<00:23,  3.85s/it]

Epoch [14/20], Train Loss: 0.7568, Train Accuracy: 0.6680, Test Accuracy: 0.6655


 75%|███████▌  | 15/20 [00:58<00:20,  4.06s/it]

Epoch [15/20], Train Loss: 0.7486, Train Accuracy: 0.6530, Test Accuracy: 0.6832


 80%|████████  | 16/20 [01:02<00:15,  3.99s/it]

Epoch [16/20], Train Loss: 0.7488, Train Accuracy: 0.6630, Test Accuracy: 0.6643


 85%|████████▌ | 17/20 [01:06<00:12,  4.07s/it]

Epoch [17/20], Train Loss: 0.7488, Train Accuracy: 0.6725, Test Accuracy: 0.6313


 90%|█████████ | 18/20 [01:10<00:08,  4.01s/it]

Epoch [18/20], Train Loss: 0.7415, Train Accuracy: 0.6657, Test Accuracy: 0.6714


 95%|█████████▌| 19/20 [01:14<00:03,  3.88s/it]

Epoch [19/20], Train Loss: 0.7204, Train Accuracy: 0.6892, Test Accuracy: 0.6796


100%|██████████| 20/20 [01:18<00:00,  3.93s/it]

Epoch [20/20], Train Loss: 0.7028, Train Accuracy: 0.6839, Test Accuracy: 0.6855





Final Test Accuracy: 0.6855


  5%|▌         | 1/20 [00:05<01:35,  5.04s/it]

Epoch [1/20], Train Loss: 0.9723, Train Accuracy: 0.5605, Test Accuracy: 0.5995


 10%|█         | 2/20 [00:10<01:30,  5.02s/it]

Epoch [2/20], Train Loss: 0.8254, Train Accuracy: 0.6362, Test Accuracy: 0.6278


 15%|█▌        | 3/20 [00:14<01:19,  4.66s/it]

Epoch [3/20], Train Loss: 0.7949, Train Accuracy: 0.6427, Test Accuracy: 0.6502


 20%|██        | 4/20 [00:18<01:10,  4.39s/it]

Epoch [4/20], Train Loss: 0.7890, Train Accuracy: 0.6492, Test Accuracy: 0.6337


 25%|██▌       | 5/20 [00:22<01:04,  4.29s/it]

Epoch [5/20], Train Loss: 0.7955, Train Accuracy: 0.6345, Test Accuracy: 0.6490


 30%|███       | 6/20 [00:26<00:57,  4.13s/it]

Epoch [6/20], Train Loss: 0.7988, Train Accuracy: 0.6459, Test Accuracy: 0.6278


 35%|███▌      | 7/20 [00:30<00:54,  4.22s/it]

Epoch [7/20], Train Loss: 0.7657, Train Accuracy: 0.6568, Test Accuracy: 0.6631


 40%|████      | 8/20 [00:34<00:50,  4.24s/it]

Epoch [8/20], Train Loss: 0.7675, Train Accuracy: 0.6530, Test Accuracy: 0.6726


 45%|████▌     | 9/20 [00:38<00:44,  4.05s/it]

Epoch [9/20], Train Loss: 0.7532, Train Accuracy: 0.6642, Test Accuracy: 0.6667


 50%|█████     | 10/20 [00:42<00:41,  4.17s/it]

Epoch [10/20], Train Loss: 0.7766, Train Accuracy: 0.6557, Test Accuracy: 0.6443


 55%|█████▌    | 11/20 [00:46<00:35,  3.89s/it]

Epoch [11/20], Train Loss: 0.7678, Train Accuracy: 0.6589, Test Accuracy: 0.6655


 60%|██████    | 12/20 [00:50<00:31,  3.91s/it]

Epoch [12/20], Train Loss: 0.7593, Train Accuracy: 0.6530, Test Accuracy: 0.6631


 65%|██████▌   | 13/20 [00:53<00:26,  3.84s/it]

Epoch [13/20], Train Loss: 0.7500, Train Accuracy: 0.6663, Test Accuracy: 0.6113


 70%|███████   | 14/20 [00:57<00:23,  3.92s/it]

Epoch [14/20], Train Loss: 0.7435, Train Accuracy: 0.6610, Test Accuracy: 0.6466


 75%|███████▌  | 15/20 [01:01<00:19,  3.91s/it]

Epoch [15/20], Train Loss: 0.7537, Train Accuracy: 0.6610, Test Accuracy: 0.6572


 80%|████████  | 16/20 [01:05<00:15,  3.96s/it]

Epoch [16/20], Train Loss: 0.7413, Train Accuracy: 0.6730, Test Accuracy: 0.6620


 85%|████████▌ | 17/20 [01:09<00:11,  3.88s/it]

Epoch [17/20], Train Loss: 0.7454, Train Accuracy: 0.6660, Test Accuracy: 0.6596


 90%|█████████ | 18/20 [01:14<00:08,  4.11s/it]

Epoch [18/20], Train Loss: 0.7384, Train Accuracy: 0.6710, Test Accuracy: 0.6584


 95%|█████████▌| 19/20 [01:18<00:04,  4.11s/it]

Epoch [19/20], Train Loss: 0.7053, Train Accuracy: 0.6816, Test Accuracy: 0.6702


100%|██████████| 20/20 [01:22<00:00,  4.12s/it]

Epoch [20/20], Train Loss: 0.6958, Train Accuracy: 0.6892, Test Accuracy: 0.6737





Final Test Accuracy: 0.6737


  5%|▌         | 1/20 [00:03<01:00,  3.17s/it]

Epoch [1/20], Train Loss: 0.9606, Train Accuracy: 0.5645, Test Accuracy: 0.5708


 10%|█         | 2/20 [00:07<01:07,  3.76s/it]

Epoch [2/20], Train Loss: 0.8516, Train Accuracy: 0.6148, Test Accuracy: 0.6344


 15%|█▌        | 3/20 [00:11<01:03,  3.76s/it]

Epoch [3/20], Train Loss: 0.7983, Train Accuracy: 0.6458, Test Accuracy: 0.6486


 20%|██        | 4/20 [00:14<01:00,  3.78s/it]

Epoch [4/20], Train Loss: 0.7771, Train Accuracy: 0.6525, Test Accuracy: 0.6533


 25%|██▌       | 5/20 [00:18<00:54,  3.61s/it]

Epoch [5/20], Train Loss: 0.7848, Train Accuracy: 0.6528, Test Accuracy: 0.6616


 30%|███       | 6/20 [00:21<00:50,  3.62s/it]

Epoch [6/20], Train Loss: 0.7844, Train Accuracy: 0.6549, Test Accuracy: 0.6592


 35%|███▌      | 7/20 [00:25<00:45,  3.49s/it]

Epoch [7/20], Train Loss: 0.7739, Train Accuracy: 0.6537, Test Accuracy: 0.6757


 40%|████      | 8/20 [00:28<00:41,  3.45s/it]

Epoch [8/20], Train Loss: 0.7683, Train Accuracy: 0.6564, Test Accuracy: 0.6745


 45%|████▌     | 9/20 [00:32<00:38,  3.50s/it]

Epoch [9/20], Train Loss: 0.7908, Train Accuracy: 0.6428, Test Accuracy: 0.6781


 50%|█████     | 10/20 [00:36<00:39,  3.93s/it]

Epoch [10/20], Train Loss: 0.7636, Train Accuracy: 0.6599, Test Accuracy: 0.6132


 55%|█████▌    | 11/20 [00:42<00:40,  4.46s/it]

Epoch [11/20], Train Loss: 0.7666, Train Accuracy: 0.6608, Test Accuracy: 0.6686


 60%|██████    | 12/20 [00:46<00:34,  4.28s/it]

Epoch [12/20], Train Loss: 0.7572, Train Accuracy: 0.6634, Test Accuracy: 0.6474


 65%|██████▌   | 13/20 [00:49<00:28,  4.04s/it]

Epoch [13/20], Train Loss: 0.7450, Train Accuracy: 0.6658, Test Accuracy: 0.6403


 70%|███████   | 14/20 [00:53<00:23,  3.91s/it]

Epoch [14/20], Train Loss: 0.7637, Train Accuracy: 0.6655, Test Accuracy: 0.6675


 75%|███████▌  | 15/20 [00:57<00:19,  3.80s/it]

Epoch [15/20], Train Loss: 0.7650, Train Accuracy: 0.6643, Test Accuracy: 0.6415


 80%|████████  | 16/20 [01:00<00:15,  3.81s/it]

Epoch [16/20], Train Loss: 0.7647, Train Accuracy: 0.6602, Test Accuracy: 0.6757


 85%|████████▌ | 17/20 [01:05<00:12,  4.12s/it]

Epoch [17/20], Train Loss: 0.7384, Train Accuracy: 0.6734, Test Accuracy: 0.6616


 90%|█████████ | 18/20 [01:10<00:08,  4.24s/it]

Epoch [18/20], Train Loss: 0.7268, Train Accuracy: 0.6782, Test Accuracy: 0.6710


 95%|█████████▌| 19/20 [01:14<00:04,  4.29s/it]

Epoch [19/20], Train Loss: 0.7127, Train Accuracy: 0.6790, Test Accuracy: 0.6675


100%|██████████| 20/20 [01:19<00:00,  3.97s/it]

Epoch [20/20], Train Loss: 0.7060, Train Accuracy: 0.6932, Test Accuracy: 0.6757





Final Test Accuracy: 0.6757
Mean Accuracy: 0.6781 ± 0.0081
