<a href="https://colab.research.google.com/github/K-Musty/Keyword-Spotting-System/blob/main/adkws.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchaudio
import librosa
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from tqdm import tqdm
from collections import Counter
import numpy as np


In [None]:
!wget http://download.tensorflow.org/data/speech_commands_v0.02.tar.gz
!mkdir -p speech_commands
!tar -xzf speech_commands_v0.02.tar.gz -C speech_commands


--2025-07-03 14:05:10--  http://download.tensorflow.org/data/speech_commands_v0.02.tar.gz
Resolving download.tensorflow.org (download.tensorflow.org)... 172.253.118.207, 172.217.194.207, 142.251.10.207, ...
Connecting to download.tensorflow.org (download.tensorflow.org)|172.253.118.207|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2428923189 (2.3G) [application/gzip]
Saving to: ‘speech_commands_v0.02.tar.gz’


2025-07-03 14:07:05 (20.4 MB/s) - ‘speech_commands_v0.02.tar.gz’ saved [2428923189/2428923189]



In [None]:
# --- Device setup ---
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

Using device: cuda


In [None]:
# --- Constants ---
# Keyword sets are kept as lists because later cells slice them into
# train / validation / test class splits.
CORE_KEYWORDS = [
    "yes", "no", "up", "down", "left",
    "right", "on", "off", "stop", "go",
]
UNKNOWN_KEYWORDS = [
    "bed", "bird", "cat", "dog", "happy",
    "house", "marvin", "sheila", "tree", "wow",
]
SAMPLE_RATE = 16000             # samples per second
AUDIO_LENGTH = 1 * SAMPLE_RATE  # number of samples in a 1-second clip


In [None]:
# --- Data loading and MFCC extraction ---
class MFCCTransform:
    """Callable that turns a waveform tensor into an MFCC feature tensor.

    The waveform is converted to a NumPy array with singleton dimensions
    removed, passed through ``librosa.feature.mfcc`` and returned as a
    float32 torch tensor.
    """

    def __call__(self, waveform):
        """Compute MFCCs for ``waveform``.

        Args:
            waveform: torch tensor holding the audio samples; it must support
                ``.numpy()``.

        Returns:
            torch.FloatTensor containing the MFCC matrix produced by librosa.
        """
        samples = waveform.numpy().squeeze()
        # 40 coefficients; with SAMPLE_RATE = 16000 the 400-sample window and
        # 160-sample hop correspond to 25 ms frames taken every 10 ms.
        mfcc_kwargs = dict(
            sr=SAMPLE_RATE,
            n_mfcc=40,
            n_fft=400,
            hop_length=160,
            win_length=400,
        )
        coefficients = librosa.feature.mfcc(y=samples, **mfcc_kwargs)
        return torch.from_numpy(coefficients).float()

class SpeechCommandsDataset(Dataset):
    """Speech Commands clips for a chosen set of keywords.

    Every ``<root_dir>/<keyword>/*.wav`` file becomes one sample whose label
    is the position of the keyword in ``keywords``. Keywords without a
    directory under ``root_dir`` contribute no samples.

    Args:
        keywords: ordered list of class names (sub-directory names).
        root_dir: directory that contains one sub-directory per keyword.
        transform: optional callable applied to the fixed-length mono
            waveform; when omitted the waveform itself is returned.

    Attributes:
        file_list: paths of all discovered ``.wav`` files.
        labels: integer class index for each entry of ``file_list``.
        label_to_idx: mapping from keyword to its integer class index.
    """

    def __init__(self, keywords, root_dir='speech_commands', transform=None):
        self.keywords = keywords
        self.root_dir = root_dir
        self.transform = transform
        self.file_list = []
        self.labels = []
        self.label_to_idx = {label: idx for idx, label in enumerate(keywords)}

        for label in keywords:
            label_dir = os.path.join(root_dir, label)
            if not os.path.isdir(label_dir):
                continue
            # os.listdir returns entries in arbitrary order; sorting makes the
            # sample index -> file mapping stable across runs and machines, so
            # seeded episode sampling picks the same clips every time.
            for filename in sorted(os.listdir(label_dir)):
                if filename.endswith('.wav'):
                    self.file_list.append(os.path.join(label_dir, filename))
                    self.labels.append(self.label_to_idx[label])

    def __len__(self):
        """Return the number of discovered ``.wav`` files."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Load sample ``idx`` as ``(features_or_waveform, label)``.

        The clip is resampled to ``SAMPLE_RATE`` if needed, averaged over
        channels to mono, and zero-padded or truncated to exactly
        ``AUDIO_LENGTH`` samples before the optional transform is applied.
        """
        filepath = self.file_list[idx]
        label = self.labels[idx]
        waveform, sr = torchaudio.load(filepath)
        if sr != SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(sr, SAMPLE_RATE)
            waveform = resampler(waveform)
        waveform = waveform.mean(dim=0, keepdim=True)  # mono
        if waveform.shape[1] < AUDIO_LENGTH:
            # Right-pad short clips with zeros up to the fixed length.
            waveform = F.pad(waveform, (0, AUDIO_LENGTH - waveform.shape[1]))
        else:
            waveform = waveform[:, :AUDIO_LENGTH]

        if self.transform:
            features = self.transform(waveform)
            return features, label
        else:
            return waveform, label

# --- TDResNet7 Encoder (Archit & Lee) ---
class ResidualBlock(nn.Module):
    """Two 7-tap dilated 1-D convolutions with a residual connection.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        stride: stride of the first convolution (and of the projection
            shortcut when one is needed).
        dilation: dilation applied to both convolutions.
    """

    def __init__(self, in_channels, out_channels, stride=1, dilation=1):
        super().__init__()
        # With a 7-tap kernel, padding of 3 * dilation keeps the sequence
        # length unchanged at stride 1.
        same_padding = 3 * dilation

        self.conv1 = nn.Conv1d(
            in_channels, out_channels, kernel_size=7, stride=stride,
            padding=same_padding, dilation=dilation, bias=False,
        )
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(
            out_channels, out_channels, kernel_size=7, stride=1,
            padding=same_padding, dilation=dilation, bias=False,
        )
        self.bn2 = nn.BatchNorm1d(out_channels)

        # The skip path is an empty Sequential (identity) unless the main path
        # changes the channel count or the temporal resolution, in which case
        # a 1x1 convolution + batch norm projects the input to match.
        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm1d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        skip = self.shortcut(x)
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return F.relu(hidden + skip)

class TDResNet7(nn.Module):
    """Temporal-convolution residual encoder producing one vector per clip.

    A 7-tap stem convolution is followed by three residual blocks with
    dilations 1, 2 and 4, and the time axis is collapsed by global average
    pooling.

    Args:
        input_dim: number of input feature channels.
        hidden_dim: channel width of every layer, and the size of the
            returned embedding.
    """

    def __init__(self, input_dim=40, hidden_dim=64):
        super().__init__()
        self.conv1 = nn.Conv1d(
            input_dim, hidden_dim, kernel_size=7, stride=1, padding=3, bias=False,
        )
        self.bn1 = nn.BatchNorm1d(hidden_dim)

        # Same width and stride throughout; only the dilation grows.
        self.layer1 = self._make_layer(hidden_dim, hidden_dim, stride=1, dilation=1)
        self.layer2 = self._make_layer(hidden_dim, hidden_dim, stride=1, dilation=2)
        self.layer3 = self._make_layer(hidden_dim, hidden_dim, stride=1, dilation=4)

        self.pool = nn.AdaptiveAvgPool1d(1)

    def _make_layer(self, in_channels, out_channels, stride, dilation):
        """Build one residual stage."""
        return ResidualBlock(in_channels, out_channels, stride, dilation)

    def forward(self, x):
        """Encode ``x`` into a ``(batch, channels)`` tensor.

        A 2-D input is treated as a single un-batched example and given a
        leading batch dimension.
        """
        if x.dim() == 2:
            x = x.unsqueeze(0)  # batch dim

        x = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3):
            x = stage(x)
        pooled = self.pool(x)
        return pooled.squeeze(-1)  # (batch, channels)


In [None]:
# --- Attention modules ---
class SelfAttention(nn.Module):
    """Scaled dot-product self-attention over the shots of each class.

    Args:
        embed_dim: size of the input embeddings and of all projections.
    """

    def __init__(self, embed_dim):
        super().__init__()
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.scale = embed_dim ** -0.5

    def forward(self, support_embeddings):
        """Attend across shots and average them into one vector per class.

        Args:
            support_embeddings: tensor of shape (n_way, n_shot, embed_dim).

        Returns:
            Tuple ``(refined, attn_weights)``: the attended shot embeddings
            averaged over dim 1, and the softmax attention weights.
        """
        q, k, v = (
            projection(support_embeddings)
            for projection in (self.query, self.key, self.value)
        )

        # Shot-to-shot similarity within each class, normalised over the keys.
        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = scores.softmax(dim=-1)

        refined = (weights @ v).mean(dim=1)
        return refined, weights

class CrossAttention(nn.Module):
    """Scaled dot-product attention from each query onto the class prototypes.

    Args:
        embed_dim: size of the input embeddings and of all projections.
    """

    def __init__(self, embed_dim):
        super().__init__()
        self.query_proj = nn.Linear(embed_dim, embed_dim)
        self.key_proj = nn.Linear(embed_dim, embed_dim)
        self.value_proj = nn.Linear(embed_dim, embed_dim)
        self.scale = embed_dim ** -0.5

    def forward(self, query_embedding, support_prototypes):
        """Blend the prototypes into one query-specific vector per query.

        Args:
            query_embedding: tensor of shape (batch, embed_dim).
            support_prototypes: tensor of shape (n_way, embed_dim).

        Returns:
            Tuple ``(refined_query_prototype, attn_weights)`` with shapes
            (batch, embed_dim) and (batch, 1, n_way).
        """
        n_queries = query_embedding.size(0)

        q = self.query_proj(query_embedding)[:, None, :]                        # (batch,1,embed_dim)
        # Every query attends over the same prototypes, so the projected
        # keys/values are broadcast along the batch axis.
        k = self.key_proj(support_prototypes)[None].expand(n_queries, -1, -1)   # (batch,n_way,embed_dim)
        v = self.value_proj(support_prototypes)[None].expand(n_queries, -1, -1)

        scores = torch.bmm(q, k.transpose(1, 2)) * self.scale
        weights = scores.softmax(dim=-1)
        blended = torch.bmm(weights, v).squeeze(1)
        return blended, weights

# --- Complete Attention Prototypical Network ---
class AttentionProtoNet(nn.Module):
    """Prototypical network whose class prototypes come from self-attention.

    Args:
        encoder: module mapping (N, channels, time) inputs to (N, embed_dim)
            embeddings.
        embed_dim: embedding size produced by ``encoder``; it must match the
            encoder's output width. Defaults to 64.
    """

    def __init__(self, encoder, embed_dim=64):
        super().__init__()
        self.encoder = encoder
        self.self_attention = SelfAttention(embed_dim=embed_dim)
        # NOTE(review): the cross-attention output never fed into the returned
        # log-probabilities, so the module is no longer invoked in forward().
        # It stays registered so existing state_dicts still load.
        self.cross_attention = CrossAttention(embed_dim=embed_dim)

    def forward(self, support, query):
        """Classify each query against the episode's prototypes.

        Args:
            support: tensor of shape (n_way, n_shot, channels, time).
            query: tensor of shape (n_query, channels, time).

        Returns:
            Log-probabilities of shape (n_query, n_way): a log-softmax over
            the negated Euclidean distances between every query embedding and
            every prototype.
        """
        n_way, n_shot, C, T = support.shape

        # reshape (rather than view) also accepts non-contiguous inputs.
        flat_support = support.reshape(n_way * n_shot, C, T)
        support_embeds = self.encoder(flat_support).reshape(n_way, n_shot, -1)  # (n_way,n_shot,embed_dim)
        query_embeds = self.encoder(query)  # (n_query, embed_dim)

        # Self-attention over support to get refined prototypes
        refined_prototypes, _ = self.self_attention(support_embeds)  # (n_way, embed_dim)

        # Distance of each query to each prototype; closer means more likely.
        all_dists = torch.cdist(query_embeds, refined_prototypes)  # (n_query, n_way)
        log_p_y = F.log_softmax(-all_dists, dim=1)

        return log_p_y

# --- Episodic sampler ---
def create_episode(dataset, n_way, n_shot, n_query):
    """Sample one N-way, K-shot episode from ``dataset``.

    Args:
        dataset: object exposing ``keywords``, ``labels`` and ``label_to_idx``
            and returning ``(features, label)`` from ``dataset[idx]``.
        n_way: number of classes in the episode.
        n_shot: support examples per class.
        n_query: query examples per class.

    Returns:
        Tuple ``(support, query, labels)``: support features stacked as
        (n_way, n_shot, ...), query features stacked as
        (n_way * n_query, ...), and the episode-local class index
        (0 .. n_way-1) of every query.

    Raises:
        ValueError: if fewer than ``n_way`` classes have at least
            ``n_shot + n_query`` samples.
    """
    min_samples = n_shot + n_query
    label_to_idx = dataset.label_to_idx

    # Group sample indices by label in a single pass, instead of rescanning
    # the full label list once per keyword and again per sampled class.
    indices_by_label = {}
    for sample_idx, label in enumerate(dataset.labels):
        indices_by_label.setdefault(label, []).append(sample_idx)

    # Filter classes with enough samples
    valid_classes = [
        k for k in dataset.keywords
        if len(indices_by_label.get(label_to_idx[k], ())) >= min_samples
    ]
    if len(valid_classes) < n_way:
        raise ValueError(f"Not enough classes with enough samples: {len(valid_classes)} found, need {n_way}")

    classes = random.sample(valid_classes, n_way)

    support_samples = []
    query_samples = []
    query_labels = []

    for episode_label, cls in enumerate(classes):
        cls_indices = indices_by_label.get(label_to_idx[cls], [])
        # One draw without replacement keeps support and query disjoint.
        selected = random.sample(cls_indices, min_samples)
        support_idxs = selected[:n_shot]
        query_idxs = selected[n_shot:]

        support_samples.append([dataset[j][0] for j in support_idxs])
        query_samples.extend([dataset[j][0] for j in query_idxs])
        # Labels are relative to this episode, not the dataset-wide index.
        query_labels.extend([episode_label] * n_query)

    support = torch.stack([torch.stack(s) for s in support_samples])
    query = torch.stack(query_samples)
    labels = torch.tensor(query_labels)

    return support, query, labels

# --- Training & Evaluation ---
def train_epoch(model, optimizer, dataset, n_way, n_shot, n_query, episodes):
    """Train ``model`` on ``episodes`` freshly sampled episodes.

    Each episode is one optimisation step: sample, forward, NLL loss on the
    query log-probabilities, backward, optimizer step.

    Returns:
        Tuple ``(mean_loss, mean_accuracy)`` averaged over the episodes.
    """
    model.train()
    loss_sum, acc_sum = 0.0, 0.0

    for _ in tqdm(range(episodes)):
        support, query, labels = (
            tensor.to(device)
            for tensor in create_episode(dataset, n_way, n_shot, n_query)
        )

        optimizer.zero_grad()
        log_probs = model(support, query)
        episode_loss = F.nll_loss(log_probs, labels)
        episode_loss.backward()
        optimizer.step()

        # Accuracy uses the predictions made before this step's update.
        correct = log_probs.argmax(dim=1) == labels

        loss_sum += episode_loss.item()
        acc_sum += correct.float().mean().item()

    return loss_sum / episodes, acc_sum / episodes

def evaluate(model, dataset, n_way, n_shot, n_query, episodes):
    """Score ``model`` on ``episodes`` freshly sampled episodes without training.

    The model is switched to eval mode and gradients are disabled for the
    whole loop.

    Returns:
        Tuple ``(mean_loss, mean_accuracy)`` averaged over the episodes.
    """
    model.eval()
    loss_sum, acc_sum = 0.0, 0.0

    with torch.no_grad():
        for _ in tqdm(range(episodes)):
            support, query, labels = (
                tensor.to(device)
                for tensor in create_episode(dataset, n_way, n_shot, n_query)
            )

            log_probs = model(support, query)
            episode_loss = F.nll_loss(log_probs, labels)
            correct = log_probs.argmax(dim=1) == labels

            loss_sum += episode_loss.item()
            acc_sum += correct.float().mean().item()

    return loss_sum / episodes, acc_sum / episodes

# --- Main Execution ---
# Seed every RNG this experiment draws from (Python's `random` for episode
# sampling, NumPy, and PyTorch for weight initialisation) so that a fresh
# Restart & Run All starts from the same weights and samples the same
# episodes. Some GPU kernels may still be non-deterministic.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Class splits: the first 8 core keywords train the model, the remaining 2
# are used for validation, and 5 of the "unknown" words form the test set.
transform = MFCCTransform()
train_dataset = SpeechCommandsDataset(CORE_KEYWORDS[:8], root_dir='speech_commands', transform=transform)
val_dataset = SpeechCommandsDataset(CORE_KEYWORDS[8:], root_dir='speech_commands', transform=transform)
test_dataset = SpeechCommandsDataset(UNKNOWN_KEYWORDS[:5], root_dir='speech_commands', transform=transform)

model = AttentionProtoNet(TDResNet7()).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Episode shape and training schedule.
n_way, n_shot, n_query = 3, 5, 15
train_episodes, val_episodes, test_episodes = 200, 100, 100
epochs = 50

# Baseline: score the untrained model on the held-out test keywords.
print("Starting initial evaluation on unknown keywords...")
test_loss, test_acc = evaluate(model, test_dataset, n_way, n_shot, n_query, test_episodes)
print(f"Initial Test Loss: {test_loss:.4f} | Initial Test Acc: {test_acc*100:.2f}%")

train_losses, train_accs, val_losses, val_accs = [], [], [], []

# Validation episodes use the training n_way unless the validation split has
# fewer classes than that. Capping with min() (rather than always taking the
# class count) keeps validation at n_way when the split has more classes.
val_n_way = min(n_way, len(val_dataset.keywords))
if val_n_way < n_way:
    print(f"Validation n_way ({val_n_way}) is less than train n_way ({n_way}). Using val_n_way={val_n_way} for validation episodes.")

for epoch in range(epochs):
    train_loss, train_acc = train_epoch(model, optimizer, train_dataset, n_way, n_shot, n_query, train_episodes)
    val_loss, val_acc = evaluate(model, val_dataset, val_n_way, n_shot, n_query, val_episodes)

    # Keep the per-epoch history for the curves plotted afterwards.
    train_losses.append(train_loss)
    train_accs.append(train_acc)
    val_losses.append(val_loss)
    val_accs.append(val_acc)

    print(f"Epoch {epoch+1}/{epochs}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc*100:.2f}% | Val Loss: {val_loss:.4f}, Val Acc: {val_acc*100:.2f}%")


# Plot training curves
# Left panel: loss per epoch; right panel: accuracy per epoch.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

loss_ax.plot(train_losses, label='Train Loss')
loss_ax.plot(val_losses, label='Val Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.legend()
loss_ax.set_title('Loss over Epochs')

acc_ax.plot(train_accs, label='Train Acc')
acc_ax.plot(val_accs, label='Val Acc')
acc_ax.set_xlabel('Epoch')
acc_ax.set_ylabel('Accuracy')
acc_ax.legend()
acc_ax.set_title('Accuracy over Epochs')

plt.show()

Starting initial evaluation on unknown keywords...


100%|██████████| 100/100 [01:13<00:00,  1.37it/s]


Initial Test Loss: 8.6199 | Initial Test Acc: 33.36%
Validation n_way (2) is less than train n_way (3). Using val_n_way=2 for validation episodes.


100%|██████████| 200/200 [02:02<00:00,  1.63it/s]
100%|██████████| 100/100 [00:38<00:00,  2.61it/s]


Epoch 1/50: Train Loss: 0.6226, Train Acc: 71.66% | Val Loss: 0.4254, Val Acc: 81.03%


100%|██████████| 200/200 [02:00<00:00,  1.65it/s]
100%|██████████| 100/100 [00:38<00:00,  2.56it/s]


Epoch 2/50: Train Loss: 0.3457, Train Acc: 85.98% | Val Loss: 0.5186, Val Acc: 74.50%


100%|██████████| 200/200 [02:00<00:00,  1.66it/s]
100%|██████████| 100/100 [00:42<00:00,  2.37it/s]


Epoch 3/50: Train Loss: 0.2170, Train Acc: 92.11% | Val Loss: 0.6021, Val Acc: 70.10%


100%|██████████| 200/200 [02:02<00:00,  1.64it/s]
100%|██████████| 100/100 [00:38<00:00,  2.63it/s]


Epoch 4/50: Train Loss: 0.1762, Train Acc: 93.59% | Val Loss: 0.5572, Val Acc: 73.87%


100%|██████████| 200/200 [01:59<00:00,  1.68it/s]
100%|██████████| 100/100 [00:38<00:00,  2.62it/s]


Epoch 5/50: Train Loss: 0.1782, Train Acc: 93.27% | Val Loss: 0.5715, Val Acc: 71.43%


100%|██████████| 200/200 [02:00<00:00,  1.66it/s]
100%|██████████| 100/100 [00:37<00:00,  2.64it/s]


Epoch 6/50: Train Loss: 0.1418, Train Acc: 94.81% | Val Loss: 0.5269, Val Acc: 75.20%


100%|██████████| 200/200 [01:58<00:00,  1.68it/s]
100%|██████████| 100/100 [00:37<00:00,  2.64it/s]


Epoch 7/50: Train Loss: 0.1164, Train Acc: 95.81% | Val Loss: 0.5081, Val Acc: 75.90%


100%|██████████| 200/200 [02:00<00:00,  1.65it/s]
100%|██████████| 100/100 [00:38<00:00,  2.62it/s]


Epoch 8/50: Train Loss: 0.1001, Train Acc: 96.36% | Val Loss: 0.4965, Val Acc: 76.37%


100%|██████████| 200/200 [02:01<00:00,  1.64it/s]
100%|██████████| 100/100 [00:38<00:00,  2.58it/s]


Epoch 9/50: Train Loss: 0.1015, Train Acc: 96.21% | Val Loss: 0.5329, Val Acc: 74.33%


100%|██████████| 200/200 [02:01<00:00,  1.64it/s]
100%|██████████| 100/100 [00:38<00:00,  2.60it/s]


Epoch 10/50: Train Loss: 0.1000, Train Acc: 96.22% | Val Loss: 0.4782, Val Acc: 78.23%


100%|██████████| 200/200 [02:02<00:00,  1.64it/s]
100%|██████████| 100/100 [00:38<00:00,  2.59it/s]


Epoch 11/50: Train Loss: 0.0973, Train Acc: 96.38% | Val Loss: 0.3881, Val Acc: 82.77%


100%|██████████| 200/200 [01:59<00:00,  1.67it/s]
100%|██████████| 100/100 [00:38<00:00,  2.63it/s]


Epoch 12/50: Train Loss: 0.0756, Train Acc: 97.31% | Val Loss: 0.4347, Val Acc: 81.33%


100%|██████████| 200/200 [02:01<00:00,  1.65it/s]
100%|██████████| 100/100 [00:38<00:00,  2.59it/s]


Epoch 13/50: Train Loss: 0.0800, Train Acc: 97.16% | Val Loss: 0.4894, Val Acc: 77.07%


100%|██████████| 200/200 [01:59<00:00,  1.67it/s]
100%|██████████| 100/100 [00:38<00:00,  2.61it/s]


Epoch 14/50: Train Loss: 0.0731, Train Acc: 97.31% | Val Loss: 0.4335, Val Acc: 79.60%


100%|██████████| 200/200 [02:01<00:00,  1.65it/s]
100%|██████████| 100/100 [00:38<00:00,  2.61it/s]


Epoch 15/50: Train Loss: 0.0785, Train Acc: 97.28% | Val Loss: 0.4194, Val Acc: 80.47%


100%|██████████| 200/200 [02:00<00:00,  1.66it/s]
100%|██████████| 100/100 [00:38<00:00,  2.58it/s]


Epoch 16/50: Train Loss: 0.0677, Train Acc: 97.60% | Val Loss: 0.3812, Val Acc: 83.13%


100%|██████████| 200/200 [02:02<00:00,  1.64it/s]
100%|██████████| 100/100 [00:38<00:00,  2.60it/s]


Epoch 17/50: Train Loss: 0.0627, Train Acc: 97.70% | Val Loss: 0.4611, Val Acc: 78.47%


100%|██████████| 200/200 [01:59<00:00,  1.67it/s]
100%|██████████| 100/100 [00:38<00:00,  2.62it/s]


Epoch 18/50: Train Loss: 0.0676, Train Acc: 97.56% | Val Loss: 0.4841, Val Acc: 78.37%


100%|██████████| 200/200 [02:01<00:00,  1.64it/s]
100%|██████████| 100/100 [00:38<00:00,  2.59it/s]


Epoch 19/50: Train Loss: 0.0601, Train Acc: 97.77% | Val Loss: 0.3994, Val Acc: 82.37%


100%|██████████| 200/200 [02:00<00:00,  1.65it/s]
100%|██████████| 100/100 [00:38<00:00,  2.57it/s]


Epoch 20/50: Train Loss: 0.0628, Train Acc: 97.80% | Val Loss: 0.4434, Val Acc: 81.73%


100%|██████████| 200/200 [02:02<00:00,  1.63it/s]
100%|██████████| 100/100 [00:38<00:00,  2.59it/s]


Epoch 21/50: Train Loss: 0.0527, Train Acc: 98.12% | Val Loss: 0.3461, Val Acc: 85.13%


100%|██████████| 200/200 [02:00<00:00,  1.65it/s]
100%|██████████| 100/100 [00:39<00:00,  2.52it/s]


Epoch 22/50: Train Loss: 0.0545, Train Acc: 97.84% | Val Loss: 0.4451, Val Acc: 81.17%


100%|██████████| 200/200 [02:00<00:00,  1.66it/s]
100%|██████████| 100/100 [00:38<00:00,  2.63it/s]


Epoch 23/50: Train Loss: 0.0579, Train Acc: 97.76% | Val Loss: 0.3620, Val Acc: 84.03%


100%|██████████| 200/200 [02:01<00:00,  1.65it/s]
100%|██████████| 100/100 [00:38<00:00,  2.62it/s]


Epoch 24/50: Train Loss: 0.0535, Train Acc: 97.89% | Val Loss: 0.5872, Val Acc: 72.67%


100%|██████████| 200/200 [02:00<00:00,  1.66it/s]
100%|██████████| 100/100 [00:38<00:00,  2.58it/s]


Epoch 25/50: Train Loss: 0.0621, Train Acc: 97.67% | Val Loss: 0.5126, Val Acc: 77.10%


100%|██████████| 200/200 [02:01<00:00,  1.64it/s]
100%|██████████| 100/100 [00:38<00:00,  2.59it/s]


Epoch 26/50: Train Loss: 0.0560, Train Acc: 97.86% | Val Loss: 0.4906, Val Acc: 77.10%


100%|██████████| 200/200 [02:03<00:00,  1.62it/s]
100%|██████████| 100/100 [00:38<00:00,  2.58it/s]


Epoch 27/50: Train Loss: 0.0539, Train Acc: 97.99% | Val Loss: 0.4902, Val Acc: 77.87%


100%|██████████| 200/200 [02:03<00:00,  1.62it/s]
100%|██████████| 100/100 [00:38<00:00,  2.58it/s]


Epoch 28/50: Train Loss: 0.0441, Train Acc: 98.33% | Val Loss: 0.4383, Val Acc: 80.23%


100%|██████████| 200/200 [02:03<00:00,  1.61it/s]
100%|██████████| 100/100 [00:38<00:00,  2.59it/s]


Epoch 29/50: Train Loss: 0.0473, Train Acc: 98.29% | Val Loss: 0.4769, Val Acc: 78.63%


100%|██████████| 200/200 [01:59<00:00,  1.67it/s]
100%|██████████| 100/100 [00:38<00:00,  2.59it/s]


Epoch 30/50: Train Loss: 0.0576, Train Acc: 97.93% | Val Loss: 0.4684, Val Acc: 79.07%


100%|██████████| 200/200 [02:02<00:00,  1.64it/s]
100%|██████████| 100/100 [00:38<00:00,  2.60it/s]


Epoch 31/50: Train Loss: 0.0497, Train Acc: 98.11% | Val Loss: 0.3867, Val Acc: 84.03%


100%|██████████| 200/200 [02:00<00:00,  1.66it/s]
100%|██████████| 100/100 [00:38<00:00,  2.58it/s]


Epoch 32/50: Train Loss: 0.0481, Train Acc: 98.11% | Val Loss: 0.3872, Val Acc: 83.90%


100%|██████████| 200/200 [02:02<00:00,  1.63it/s]
100%|██████████| 100/100 [00:38<00:00,  2.57it/s]


Epoch 33/50: Train Loss: 0.0462, Train Acc: 98.26% | Val Loss: 0.3608, Val Acc: 84.13%


100%|██████████| 200/200 [02:02<00:00,  1.64it/s]
100%|██████████| 100/100 [00:38<00:00,  2.59it/s]


Epoch 34/50: Train Loss: 0.0505, Train Acc: 98.10% | Val Loss: 0.4478, Val Acc: 79.97%


100%|██████████| 200/200 [02:00<00:00,  1.66it/s]
100%|██████████| 100/100 [00:38<00:00,  2.59it/s]


Epoch 35/50: Train Loss: 0.0419, Train Acc: 98.50% | Val Loss: 0.5266, Val Acc: 76.23%


100%|██████████| 200/200 [02:02<00:00,  1.64it/s]
100%|██████████| 100/100 [00:38<00:00,  2.59it/s]


Epoch 36/50: Train Loss: 0.0394, Train Acc: 98.68% | Val Loss: 0.4673, Val Acc: 77.53%


100%|██████████| 200/200 [02:00<00:00,  1.66it/s]
100%|██████████| 100/100 [00:38<00:00,  2.60it/s]


Epoch 37/50: Train Loss: 0.0353, Train Acc: 98.73% | Val Loss: 0.4751, Val Acc: 77.33%


100%|██████████| 200/200 [02:01<00:00,  1.65it/s]
100%|██████████| 100/100 [00:38<00:00,  2.61it/s]


Epoch 38/50: Train Loss: 0.0482, Train Acc: 98.33% | Val Loss: 0.5575, Val Acc: 75.60%


 87%|████████▋ | 174/200 [01:44<00:14,  1.79it/s]