In [1]:
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.model_selection import KFold

from data_loader import load_data, get_balanced_data, normalize_features

# import wandb
# wandb.init(project='gene')

data = load_data()
X_balanced, y_balanced = get_balanced_data(data)
FEATURE_TYPES = ['mcg', 'atac', 'hic', 'genebody']
for k, v in X_balanced.items():
    print(k, len(v))


Processed mcg data
Processed genebody data
Processed atac data
Processed hic data
zero: 11118, non-zero: 3467
mcg 5200
genebody 5200
atac 5200
hic 5200


In [4]:
HIDDEN_DIM = 16
NUM_LAYERS = 2
NUM_HEADS = 1
DROPOUT = 0.2
LR = 0.0002
OUTPUT_DIM = 3  # number of classes (-1, 0, 1)
NUM_EPOCHS = 20
BATCH_SIZE = 32

# HIDDEN_DIM = 64
# NUM_LAYERS = 1
# NUM_HEADS = 1
# DROPOUT = 0.3
# LR = 1e-4
# OUTPUT_DIM = 3  # number of classes (-1, 0, 1)
# NUM_EPOCHS = 200
# BATCH_SIZE = 16


class FourHeadTransformerModel(nn.Module):
    def __init__(self, mcg_input_dim, atac_input_dim, hic_input_dim, genebody_input_dim, hidden_dim, output_dim, num_layers=2, num_heads=1, dropout=0.1):
        super(FourHeadTransformerModel, self).__init__()
        self.mcg_embedding = nn.Linear(mcg_input_dim, hidden_dim)
        self.atac_embedding = nn.Linear(atac_input_dim, hidden_dim)
        self.hic_embedding = nn.Linear(hic_input_dim, hidden_dim)
        self.genebody_embedding = nn.Linear(genebody_input_dim, hidden_dim)

        # TODO: may need to use tanh in attention instead of softmax
        encoder_layers = nn.TransformerEncoderLayer(hidden_dim, num_heads, dim_feedforward=hidden_dim*2, dropout=dropout, batch_first=True, norm_first=True)
        self.mcg_transformer = nn.TransformerEncoder(encoder_layers, num_layers)
        self.atac_transformer = nn.TransformerEncoder(encoder_layers, num_layers)
        self.hic_transformer = nn.TransformerEncoder(encoder_layers, num_layers)
        self.genebody_transformer = nn.TransformerEncoder(encoder_layers, num_layers)
        
        self.classifier = nn.Linear(hidden_dim * 4, output_dim)
        self.hidden_dim = hidden_dim

    def forward(self, mcg_x, mcg_mask, atac_x, atac_mask, hic_x, hic_mask, genebody_x, genebody_mask):
        mcg_x = self.mcg_embedding(mcg_x)
        atac_x = self.atac_embedding(atac_x)
        hic_x = self.hic_embedding(hic_x)
        genebody_x = self.genebody_embedding(genebody_x)
        
        mcg_x = self.mcg_transformer(mcg_x, src_key_padding_mask=~mcg_mask.bool())
        atac_x = self.atac_transformer(atac_x, src_key_padding_mask=~atac_mask.bool())
        hic_x = self.hic_transformer(hic_x, src_key_padding_mask=~hic_mask.bool())
        genebody_x = self.genebody_transformer(genebody_x, src_key_padding_mask=~genebody_mask.bool())
        
        # Global average pooling
        mcg_x = mcg_x.mean(dim=1)
        atac_x = atac_x.mean(dim=1)
        hic_x = hic_x.mean(dim=1)
        genebody_x = genebody_x.mean(dim=1)

        # Concatenate MCG and ATAC embeddings
        combined_x = torch.cat((mcg_x, atac_x, hic_x, genebody_x), dim=1)
        
        output = self.classifier(combined_x)
        return output

class CombinedGeneDataset(Dataset):
    def __init__(self, mcg_data, atac_data, hic_data, genebody_data, labels):
        self.mcg_data = mcg_data
        self.atac_data = atac_data
        self.hic_data = hic_data
        self.genebody_data = genebody_data
        self.labels = labels
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        mcg_gene_data = torch.FloatTensor(self.mcg_data[idx])
        atac_gene_data = torch.FloatTensor(self.atac_data[idx])
        hic_gene_data = torch.FloatTensor(self.hic_data[idx])
        genebody_gene_data = torch.FloatTensor(self.genebody_data[idx])
        label = torch.LongTensor([self.labels[idx] + 1])  # Add 1 to shift labels to 0, 1, 2
        mcg_mask = torch.ones(len(mcg_gene_data))
        atac_mask = torch.ones(len(atac_gene_data))
        hic_mask = torch.ones(len(hic_gene_data))
        genebody_mask = torch.ones(len(genebody_gene_data))
        return mcg_gene_data, atac_gene_data, hic_gene_data, genebody_gene_data, label, mcg_mask, atac_mask, hic_mask, genebody_mask

def combined_collate_fn(batch):
    batch.sort(key=lambda x: len(x[0]), reverse=True)
    mcg_sequences, atac_sequences, hic_sequences, genebody_sequences, labels, mcg_masks, atac_masks, hic_masks, genebody_masks = zip(*batch)
    
    mcg_lengths = [len(seq) for seq in mcg_sequences]
    atac_lengths = [len(seq) for seq in atac_sequences]
    hic_lengths = [len(seq) for seq in hic_sequences]
    genebody_lengths = [len(seq) for seq in genebody_sequences]
    mcg_max_len = max(mcg_lengths)
    atac_max_len = max(atac_lengths)
    hic_max_len = max(hic_lengths)
    genebody_max_len = max(genebody_lengths)
    
    padded_mcg_seqs = torch.zeros(len(mcg_sequences), mcg_max_len, mcg_sequences[0].size(1))
    padded_atac_seqs = torch.zeros(len(atac_sequences), atac_max_len, atac_sequences[0].size(1))
    padded_hic_seqs = torch.zeros(len(hic_sequences), hic_max_len, hic_sequences[0].size(1))
    padded_genebody_seqs = torch.zeros(len(genebody_sequences), genebody_max_len, genebody_sequences[0].size(1))
    padded_mcg_masks = torch.zeros(len(mcg_sequences), mcg_max_len)
    padded_atac_masks = torch.zeros(len(atac_sequences), atac_max_len)
    padded_hic_masks = torch.zeros(len(hic_sequences), hic_max_len)
    padded_genebody_masks = torch.zeros(len(genebody_sequences), genebody_max_len)
    
    for i, (mcg_seq, atac_seq, hic_seq, genebody_seq, mcg_length, atac_length, hic_length, genebody_length) in enumerate(zip(mcg_sequences, atac_sequences, hic_sequences, genebody_sequences, mcg_lengths, atac_lengths, hic_lengths, genebody_lengths)):
        padded_mcg_seqs[i, :mcg_length] = mcg_seq
        padded_atac_seqs[i, :atac_length] = atac_seq
        padded_hic_seqs[i, :hic_length] = hic_seq
        padded_genebody_seqs[i, :genebody_length] = genebody_seq
        padded_mcg_masks[i, :mcg_length] = 1
        padded_atac_masks[i, :atac_length] = 1
        padded_hic_masks[i, :hic_length] = 1
        padded_genebody_masks[i, :genebody_length] = 1
    
    return padded_mcg_seqs, padded_atac_seqs, padded_hic_seqs, padded_genebody_seqs, torch.cat(labels), padded_mcg_masks, padded_atac_masks, padded_hic_masks, padded_genebody_masks

def train_combined_model(X_train_mcg, X_train_atac, X_train_hic, X_train_genebody, y_train, X_test_mcg, X_test_atac, X_test_hic, X_test_genebody, y_test, exp_name, fold_idx):
    #wandb.init(project='gene', group=exp_name, name=f'fold-{fold_idx}')
    mcg_input_dim = len(X_train_mcg[0][0])
    atac_input_dim = len(X_train_atac[0][0])
    hic_input_dim = len(X_train_hic[0][0])
    genebody_input_dim = len(X_train_genebody[0][0])
    
    train_dataset = CombinedGeneDataset(X_train_mcg, X_train_atac, X_train_hic, X_train_genebody, y_train)
    test_dataset = CombinedGeneDataset(X_test_mcg, X_test_atac, X_test_hic, X_test_genebody, y_test)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=combined_collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=combined_collate_fn)

    model = FourHeadTransformerModel(mcg_input_dim, atac_input_dim, hic_input_dim, genebody_input_dim, HIDDEN_DIM, OUTPUT_DIM, num_layers=NUM_LAYERS, num_heads=NUM_HEADS, dropout=DROPOUT)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LR)

    lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=0.0001)
    # Create the OneCycleLR scheduler
    # lr_scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=LR, total_steps=NUM_EPOCHS,
    #                           pct_start=0.8, anneal_strategy='cos',
    #                           cycle_momentum=False, div_factor=5.0,
    #                           final_div_factor=10.0)

    for epoch in tqdm(range(NUM_EPOCHS)):
        model.train()
        total_loss = 0
        train_correct = 0
        train_total = 0
        for mcg_x, atac_x, hic_x, genebody_x, batch_y, mcg_mask, atac_mask, hic_mask, genebody_mask in train_loader:
            optimizer.zero_grad()
            outputs = model(mcg_x, mcg_mask, atac_x, atac_mask, hic_x, hic_mask, genebody_x, genebody_mask)
            loss = criterion(outputs, batch_y.squeeze())
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            train_correct += (outputs.argmax(dim=1) == batch_y.squeeze()).sum().item()
            train_total += batch_y.size(0)
        lr_scheduler.step()

        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for mcg_x, atac_x, hic_x, genebody_x, batch_y, mcg_mask, atac_mask, hic_mask, genebody_mask in test_loader:
                outputs = model(mcg_x, mcg_mask, atac_x, atac_mask, hic_x, hic_mask, genebody_x, genebody_mask)
                _, predicted = torch.max(outputs.data, 1)
                total += batch_y.size(0)
                correct += (predicted == batch_y.squeeze()).sum().item()
        
        accuracy = correct / total
        print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], Train Loss: {total_loss/len(train_loader):.4f}, Train Accuracy: {train_correct/train_total:.4f}, Test Accuracy: {accuracy:.4f}')
        #wandb.log({'epoch': epoch, 'LR': optimizer.param_groups[0]['lr'], 'train_loss': total_loss/len(train_loader), 'train_accuracy': train_correct/train_total, 'test_accuracy': accuracy})

    model.eval()
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for mcg_x, atac_x, hic_x, genebody_x, batch_y, mcg_mask, atac_mask, hic_mask, genebody_mask in test_loader:
            outputs = model(mcg_x, mcg_mask, atac_x, atac_mask, hic_x, hic_mask, genebody_x, genebody_mask)
            _, predicted = torch.max(outputs.data, 1)
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(batch_y.cpu().numpy())

    final_accuracy = sum(np.array(all_predictions) == np.array(all_labels).squeeze()) / len(all_labels)
    print(f'Final Test Accuracy: {final_accuracy:.4f}')
    return final_accuracy

In [5]:
import time
exp_name = f'two-head-{time.strftime("%Y%m%d-%H%M%S")}'

# wandb.config.update({
#     'hidden_dim': HIDDEN_DIM,
#     'num_layers': NUM_LAYERS,
#     'num_heads': NUM_HEADS,
#     'dropout': DROPOUT,
#     'lr': LR,
#     'output_dim': OUTPUT_DIM,
#     'num_epochs': NUM_EPOCHS
# })


kf = KFold(n_splits=5, shuffle=True, random_state=25)
accuracies = []
for i, (train_index, test_index) in enumerate(kf.split(X_balanced['mcg'])):
    X_train_mcg, X_test_mcg = [X_balanced['mcg'][i] for i in train_index], [X_balanced['mcg'][i] for i in test_index]
    X_train_atac, X_test_atac = [X_balanced['atac'][i] for i in train_index], [X_balanced['atac'][i] for i in test_index]
    X_train_hic, X_test_hic = [X_balanced['hic'][i] for i in train_index], [X_balanced['hic'][i] for i in test_index]
    X_train_genebody, X_test_genebody = [X_balanced['genebody'][i] for i in train_index], [X_balanced['genebody'][i] for i in test_index]
    y_train, y_test = [y_balanced[i] for i in train_index], [y_balanced[i] for i in test_index]
    
    X_train_mcg_normalized, X_test_mcg_normalized = normalize_features(X_train_mcg, X_test_mcg)
    X_train_atac_normalized, X_test_atac_normalized = normalize_features(X_train_atac, X_test_atac)
    X_train_hic_normalized, X_test_hic_normalized = normalize_features(X_train_hic, X_test_hic)
    X_train_genebody_normalized, X_test_genebody_normalized = normalize_features(X_train_genebody, X_test_genebody)
    
    accuracies.append(train_combined_model(X_train_mcg_normalized, X_train_atac_normalized, X_train_hic_normalized, X_train_genebody_normalized, y_train, 
                                           X_test_mcg_normalized, X_test_atac_normalized, X_test_hic_normalized, X_test_genebody_normalized, y_test, exp_name=exp_name, fold_idx=i))

print(f'Mean Accuracy: {np.mean(accuracies):.4f} ± {np.std(accuracies):.4f}')

  0%|          | 0/20 [00:00<?, ?it/s]

  5%|▌         | 1/20 [00:02<00:50,  2.65s/it]

Epoch [1/20], Train Loss: 1.0468, Train Accuracy: 0.4454, Test Accuracy: 0.4865


 10%|█         | 2/20 [00:05<00:47,  2.62s/it]

Epoch [2/20], Train Loss: 1.0004, Train Accuracy: 0.4877, Test Accuracy: 0.4933


 15%|█▌        | 3/20 [00:07<00:44,  2.63s/it]

Epoch [3/20], Train Loss: 0.9739, Train Accuracy: 0.5022, Test Accuracy: 0.5317


 20%|██        | 4/20 [00:10<00:41,  2.61s/it]

Epoch [4/20], Train Loss: 0.9609, Train Accuracy: 0.5168, Test Accuracy: 0.5327


 25%|██▌       | 5/20 [00:13<00:38,  2.60s/it]

Epoch [5/20], Train Loss: 0.9414, Train Accuracy: 0.5298, Test Accuracy: 0.5404


 30%|███       | 6/20 [00:15<00:36,  2.58s/it]

Epoch [6/20], Train Loss: 0.9255, Train Accuracy: 0.5423, Test Accuracy: 0.5673


 35%|███▌      | 7/20 [00:18<00:33,  2.58s/it]

Epoch [7/20], Train Loss: 0.8994, Train Accuracy: 0.5608, Test Accuracy: 0.5788


 40%|████      | 8/20 [00:20<00:31,  2.60s/it]

Epoch [8/20], Train Loss: 0.8908, Train Accuracy: 0.5680, Test Accuracy: 0.5808


 45%|████▌     | 9/20 [00:23<00:28,  2.59s/it]

Epoch [9/20], Train Loss: 0.8854, Train Accuracy: 0.5659, Test Accuracy: 0.5817


 50%|█████     | 10/20 [00:25<00:25,  2.59s/it]

Epoch [10/20], Train Loss: 0.8794, Train Accuracy: 0.5704, Test Accuracy: 0.5788


 55%|█████▌    | 11/20 [00:28<00:23,  2.58s/it]

Epoch [11/20], Train Loss: 0.8719, Train Accuracy: 0.5700, Test Accuracy: 0.5740


 60%|██████    | 12/20 [00:31<00:20,  2.57s/it]

Epoch [12/20], Train Loss: 0.8720, Train Accuracy: 0.5781, Test Accuracy: 0.5788


 65%|██████▌   | 13/20 [00:33<00:18,  2.59s/it]

Epoch [13/20], Train Loss: 0.8730, Train Accuracy: 0.5779, Test Accuracy: 0.5981


 70%|███████   | 14/20 [00:36<00:15,  2.58s/it]

Epoch [14/20], Train Loss: 0.8721, Train Accuracy: 0.5793, Test Accuracy: 0.5962


 75%|███████▌  | 15/20 [00:38<00:12,  2.58s/it]

Epoch [15/20], Train Loss: 0.8701, Train Accuracy: 0.5810, Test Accuracy: 0.5913


 80%|████████  | 16/20 [00:41<00:10,  2.58s/it]

Epoch [16/20], Train Loss: 0.8738, Train Accuracy: 0.5678, Test Accuracy: 0.5923


 85%|████████▌ | 17/20 [00:44<00:07,  2.59s/it]

Epoch [17/20], Train Loss: 0.8672, Train Accuracy: 0.5760, Test Accuracy: 0.5990


 90%|█████████ | 18/20 [00:46<00:05,  2.68s/it]

Epoch [18/20], Train Loss: 0.8680, Train Accuracy: 0.5829, Test Accuracy: 0.5942


 95%|█████████▌| 19/20 [00:49<00:02,  2.76s/it]

Epoch [19/20], Train Loss: 0.8610, Train Accuracy: 0.5764, Test Accuracy: 0.6010


100%|██████████| 20/20 [00:52<00:00,  2.62s/it]

Epoch [20/20], Train Loss: 0.8622, Train Accuracy: 0.5776, Test Accuracy: 0.5990
Final Test Accuracy: 0.5990



  5%|▌         | 1/20 [00:02<00:49,  2.58s/it]

Epoch [1/20], Train Loss: 1.0580, Train Accuracy: 0.4596, Test Accuracy: 0.4683


 10%|█         | 2/20 [00:05<00:46,  2.58s/it]

Epoch [2/20], Train Loss: 0.9881, Train Accuracy: 0.4921, Test Accuracy: 0.4837


 15%|█▌        | 3/20 [00:07<00:44,  2.59s/it]

Epoch [3/20], Train Loss: 0.9618, Train Accuracy: 0.5079, Test Accuracy: 0.5087


 20%|██        | 4/20 [00:10<00:41,  2.58s/it]

Epoch [4/20], Train Loss: 0.9515, Train Accuracy: 0.5137, Test Accuracy: 0.5327


 25%|██▌       | 5/20 [00:12<00:38,  2.57s/it]

Epoch [5/20], Train Loss: 0.9405, Train Accuracy: 0.5212, Test Accuracy: 0.5385


 30%|███       | 6/20 [00:15<00:36,  2.57s/it]

Epoch [6/20], Train Loss: 0.9281, Train Accuracy: 0.5320, Test Accuracy: 0.5462


 35%|███▌      | 7/20 [00:18<00:33,  2.59s/it]

Epoch [7/20], Train Loss: 0.9043, Train Accuracy: 0.5531, Test Accuracy: 0.5721


 40%|████      | 8/20 [00:20<00:31,  2.59s/it]

Epoch [8/20], Train Loss: 0.8940, Train Accuracy: 0.5623, Test Accuracy: 0.5663


 45%|████▌     | 9/20 [00:23<00:28,  2.59s/it]

Epoch [9/20], Train Loss: 0.8888, Train Accuracy: 0.5661, Test Accuracy: 0.5788


 50%|█████     | 10/20 [00:25<00:26,  2.61s/it]

Epoch [10/20], Train Loss: 0.8910, Train Accuracy: 0.5632, Test Accuracy: 0.5721


 55%|█████▌    | 11/20 [00:28<00:23,  2.61s/it]

Epoch [11/20], Train Loss: 0.8888, Train Accuracy: 0.5596, Test Accuracy: 0.5625


 60%|██████    | 12/20 [00:31<00:20,  2.60s/it]

Epoch [12/20], Train Loss: 0.8806, Train Accuracy: 0.5637, Test Accuracy: 0.5712


 65%|██████▌   | 13/20 [00:33<00:18,  2.60s/it]

Epoch [13/20], Train Loss: 0.8799, Train Accuracy: 0.5625, Test Accuracy: 0.5673


 70%|███████   | 14/20 [00:36<00:15,  2.60s/it]

Epoch [14/20], Train Loss: 0.8770, Train Accuracy: 0.5743, Test Accuracy: 0.5577


 75%|███████▌  | 15/20 [00:38<00:12,  2.59s/it]

Epoch [15/20], Train Loss: 0.8818, Train Accuracy: 0.5584, Test Accuracy: 0.5740


 80%|████████  | 16/20 [00:41<00:10,  2.61s/it]

Epoch [16/20], Train Loss: 0.8765, Train Accuracy: 0.5649, Test Accuracy: 0.5596


 85%|████████▌ | 17/20 [00:44<00:07,  2.61s/it]

Epoch [17/20], Train Loss: 0.8767, Train Accuracy: 0.5745, Test Accuracy: 0.5538


 90%|█████████ | 18/20 [00:46<00:05,  2.62s/it]

Epoch [18/20], Train Loss: 0.8727, Train Accuracy: 0.5755, Test Accuracy: 0.5519


 95%|█████████▌| 19/20 [00:49<00:02,  2.59s/it]

Epoch [19/20], Train Loss: 0.8721, Train Accuracy: 0.5702, Test Accuracy: 0.5808


100%|██████████| 20/20 [00:51<00:00,  2.59s/it]

Epoch [20/20], Train Loss: 0.8700, Train Accuracy: 0.5721, Test Accuracy: 0.5712





Final Test Accuracy: 0.5712


  5%|▌         | 1/20 [00:03<00:58,  3.08s/it]

Epoch [1/20], Train Loss: 1.0276, Train Accuracy: 0.4738, Test Accuracy: 0.4740


 10%|█         | 2/20 [00:05<00:50,  2.82s/it]

Epoch [2/20], Train Loss: 1.0079, Train Accuracy: 0.4930, Test Accuracy: 0.4798


 15%|█▌        | 3/20 [00:08<00:45,  2.70s/it]

Epoch [3/20], Train Loss: 0.9936, Train Accuracy: 0.5002, Test Accuracy: 0.4837


 20%|██        | 4/20 [00:10<00:42,  2.64s/it]

Epoch [4/20], Train Loss: 0.9795, Train Accuracy: 0.5050, Test Accuracy: 0.5048


 25%|██▌       | 5/20 [00:13<00:38,  2.59s/it]

Epoch [5/20], Train Loss: 0.9462, Train Accuracy: 0.5327, Test Accuracy: 0.5442


 30%|███       | 6/20 [00:15<00:36,  2.58s/it]

Epoch [6/20], Train Loss: 0.9168, Train Accuracy: 0.5476, Test Accuracy: 0.5644


 35%|███▌      | 7/20 [00:18<00:33,  2.57s/it]

Epoch [7/20], Train Loss: 0.9087, Train Accuracy: 0.5538, Test Accuracy: 0.5452


 40%|████      | 8/20 [00:21<00:30,  2.58s/it]

Epoch [8/20], Train Loss: 0.8975, Train Accuracy: 0.5637, Test Accuracy: 0.5712


 45%|████▌     | 9/20 [00:23<00:28,  2.58s/it]

Epoch [9/20], Train Loss: 0.8902, Train Accuracy: 0.5608, Test Accuracy: 0.5654


 50%|█████     | 10/20 [00:26<00:25,  2.57s/it]

Epoch [10/20], Train Loss: 0.8927, Train Accuracy: 0.5572, Test Accuracy: 0.5673


 55%|█████▌    | 11/20 [00:28<00:23,  2.56s/it]

Epoch [11/20], Train Loss: 0.8853, Train Accuracy: 0.5738, Test Accuracy: 0.5654


 60%|██████    | 12/20 [00:31<00:20,  2.55s/it]

Epoch [12/20], Train Loss: 0.8813, Train Accuracy: 0.5534, Test Accuracy: 0.5683


 65%|██████▌   | 13/20 [00:33<00:17,  2.55s/it]

Epoch [13/20], Train Loss: 0.8786, Train Accuracy: 0.5712, Test Accuracy: 0.5663


 70%|███████   | 14/20 [00:36<00:15,  2.54s/it]

Epoch [14/20], Train Loss: 0.8786, Train Accuracy: 0.5740, Test Accuracy: 0.5702


 75%|███████▌  | 15/20 [00:38<00:12,  2.54s/it]

Epoch [15/20], Train Loss: 0.8780, Train Accuracy: 0.5822, Test Accuracy: 0.5731


 80%|████████  | 16/20 [00:41<00:10,  2.55s/it]

Epoch [16/20], Train Loss: 0.8740, Train Accuracy: 0.5760, Test Accuracy: 0.5721


 85%|████████▌ | 17/20 [00:43<00:07,  2.54s/it]

Epoch [17/20], Train Loss: 0.8702, Train Accuracy: 0.5803, Test Accuracy: 0.5635


 90%|█████████ | 18/20 [00:46<00:05,  2.54s/it]

Epoch [18/20], Train Loss: 0.8713, Train Accuracy: 0.5733, Test Accuracy: 0.5654


 95%|█████████▌| 19/20 [00:48<00:02,  2.54s/it]

Epoch [19/20], Train Loss: 0.8754, Train Accuracy: 0.5757, Test Accuracy: 0.5731


100%|██████████| 20/20 [00:51<00:00,  2.58s/it]

Epoch [20/20], Train Loss: 0.8744, Train Accuracy: 0.5760, Test Accuracy: 0.5673
Final Test Accuracy: 0.5673



  5%|▌         | 1/20 [00:02<00:47,  2.51s/it]

Epoch [1/20], Train Loss: 1.0520, Train Accuracy: 0.4606, Test Accuracy: 0.4913


 10%|█         | 2/20 [00:05<00:45,  2.51s/it]

Epoch [2/20], Train Loss: 1.0200, Train Accuracy: 0.4740, Test Accuracy: 0.5048


 15%|█▌        | 3/20 [00:07<00:45,  2.70s/it]

Epoch [3/20], Train Loss: 1.0007, Train Accuracy: 0.4909, Test Accuracy: 0.5029


 20%|██        | 4/20 [00:10<00:45,  2.83s/it]

Epoch [4/20], Train Loss: 0.9842, Train Accuracy: 0.4998, Test Accuracy: 0.4952


 25%|██▌       | 5/20 [00:13<00:41,  2.74s/it]

Epoch [5/20], Train Loss: 0.9700, Train Accuracy: 0.5132, Test Accuracy: 0.4971


 30%|███       | 6/20 [00:16<00:37,  2.67s/it]

Epoch [6/20], Train Loss: 0.9579, Train Accuracy: 0.5111, Test Accuracy: 0.5269


 35%|███▌      | 7/20 [00:18<00:34,  2.64s/it]

Epoch [7/20], Train Loss: 0.9391, Train Accuracy: 0.5233, Test Accuracy: 0.5154


 40%|████      | 8/20 [00:21<00:31,  2.62s/it]

Epoch [8/20], Train Loss: 0.9301, Train Accuracy: 0.5327, Test Accuracy: 0.5587


 45%|████▌     | 9/20 [00:23<00:28,  2.60s/it]

Epoch [9/20], Train Loss: 0.9274, Train Accuracy: 0.5339, Test Accuracy: 0.5442


 50%|█████     | 10/20 [00:26<00:25,  2.59s/it]

Epoch [10/20], Train Loss: 0.9104, Train Accuracy: 0.5526, Test Accuracy: 0.5587


 55%|█████▌    | 11/20 [00:28<00:23,  2.58s/it]

Epoch [11/20], Train Loss: 0.9010, Train Accuracy: 0.5555, Test Accuracy: 0.5683


 60%|██████    | 12/20 [00:31<00:20,  2.56s/it]

Epoch [12/20], Train Loss: 0.9051, Train Accuracy: 0.5618, Test Accuracy: 0.5596


 65%|██████▌   | 13/20 [00:33<00:17,  2.54s/it]

Epoch [13/20], Train Loss: 0.8892, Train Accuracy: 0.5714, Test Accuracy: 0.5673


 70%|███████   | 14/20 [00:36<00:15,  2.54s/it]

Epoch [14/20], Train Loss: 0.8951, Train Accuracy: 0.5654, Test Accuracy: 0.5663


 75%|███████▌  | 15/20 [00:39<00:12,  2.54s/it]

Epoch [15/20], Train Loss: 0.8878, Train Accuracy: 0.5709, Test Accuracy: 0.5596


 80%|████████  | 16/20 [00:41<00:10,  2.55s/it]

Epoch [16/20], Train Loss: 0.8879, Train Accuracy: 0.5644, Test Accuracy: 0.5798


 85%|████████▌ | 17/20 [00:44<00:07,  2.54s/it]

Epoch [17/20], Train Loss: 0.8855, Train Accuracy: 0.5644, Test Accuracy: 0.5663


 90%|█████████ | 18/20 [00:46<00:05,  2.55s/it]

Epoch [18/20], Train Loss: 0.8829, Train Accuracy: 0.5721, Test Accuracy: 0.5615


 95%|█████████▌| 19/20 [00:49<00:02,  2.54s/it]

Epoch [19/20], Train Loss: 0.8820, Train Accuracy: 0.5748, Test Accuracy: 0.5673


100%|██████████| 20/20 [00:51<00:00,  2.59s/it]

Epoch [20/20], Train Loss: 0.8788, Train Accuracy: 0.5779, Test Accuracy: 0.5644
Final Test Accuracy: 0.5644



  5%|▌         | 1/20 [00:02<00:48,  2.56s/it]

Epoch [1/20], Train Loss: 1.0407, Train Accuracy: 0.4481, Test Accuracy: 0.5077


 10%|█         | 2/20 [00:05<00:45,  2.55s/it]

Epoch [2/20], Train Loss: 1.0160, Train Accuracy: 0.4762, Test Accuracy: 0.5240


 15%|█▌        | 3/20 [00:07<00:43,  2.57s/it]

Epoch [3/20], Train Loss: 1.0030, Train Accuracy: 0.4839, Test Accuracy: 0.5317


 20%|██        | 4/20 [00:10<00:41,  2.58s/it]

Epoch [4/20], Train Loss: 0.9902, Train Accuracy: 0.4918, Test Accuracy: 0.5346


 25%|██▌       | 5/20 [00:12<00:38,  2.57s/it]

Epoch [5/20], Train Loss: 0.9701, Train Accuracy: 0.5053, Test Accuracy: 0.5500


 30%|███       | 6/20 [00:15<00:36,  2.57s/it]

Epoch [6/20], Train Loss: 0.9589, Train Accuracy: 0.5103, Test Accuracy: 0.5500


 35%|███▌      | 7/20 [00:18<00:35,  2.74s/it]

Epoch [7/20], Train Loss: 0.9522, Train Accuracy: 0.5139, Test Accuracy: 0.5673


 40%|████      | 8/20 [00:21<00:32,  2.74s/it]

Epoch [8/20], Train Loss: 0.9377, Train Accuracy: 0.5310, Test Accuracy: 0.5760


 45%|████▌     | 9/20 [00:23<00:29,  2.67s/it]

Epoch [9/20], Train Loss: 0.9271, Train Accuracy: 0.5368, Test Accuracy: 0.5779


 50%|█████     | 10/20 [00:26<00:26,  2.64s/it]

Epoch [10/20], Train Loss: 0.9225, Train Accuracy: 0.5486, Test Accuracy: 0.5769


 55%|█████▌    | 11/20 [00:28<00:23,  2.61s/it]

Epoch [11/20], Train Loss: 0.9134, Train Accuracy: 0.5450, Test Accuracy: 0.5913


 60%|██████    | 12/20 [00:31<00:20,  2.61s/it]

Epoch [12/20], Train Loss: 0.9000, Train Accuracy: 0.5512, Test Accuracy: 0.5798


 65%|██████▌   | 13/20 [00:34<00:18,  2.58s/it]

Epoch [13/20], Train Loss: 0.9018, Train Accuracy: 0.5620, Test Accuracy: 0.5788


 70%|███████   | 14/20 [00:36<00:15,  2.58s/it]

Epoch [14/20], Train Loss: 0.8918, Train Accuracy: 0.5625, Test Accuracy: 0.5962


 75%|███████▌  | 15/20 [00:39<00:12,  2.58s/it]

Epoch [15/20], Train Loss: 0.8877, Train Accuracy: 0.5625, Test Accuracy: 0.5904


 80%|████████  | 16/20 [00:41<00:10,  2.57s/it]

Epoch [16/20], Train Loss: 0.8943, Train Accuracy: 0.5534, Test Accuracy: 0.5962


 85%|████████▌ | 17/20 [00:44<00:07,  2.56s/it]

Epoch [17/20], Train Loss: 0.8954, Train Accuracy: 0.5577, Test Accuracy: 0.5952


 90%|█████████ | 18/20 [00:46<00:05,  2.55s/it]

Epoch [18/20], Train Loss: 0.8935, Train Accuracy: 0.5601, Test Accuracy: 0.6019


 95%|█████████▌| 19/20 [00:49<00:02,  2.56s/it]

Epoch [19/20], Train Loss: 0.8858, Train Accuracy: 0.5620, Test Accuracy: 0.6038


100%|██████████| 20/20 [00:51<00:00,  2.60s/it]

Epoch [20/20], Train Loss: 0.8839, Train Accuracy: 0.5594, Test Accuracy: 0.6077
Final Test Accuracy: 0.6077
Mean Accuracy: 0.5819 ± 0.0178





In [6]:
print(f'Mean Accuracy: {np.mean(accuracies):.4f} ± {np.std(accuracies):.4f}')

Mean Accuracy: 0.5819 ± 0.0178
