**BEST Statistics:**

**Classification Report:**

              precision    recall  f1-score   support

           0       0.54      0.63      0.58      6000
           1       0.74      0.71      0.72      6000
           2       0.77      0.77      0.77      6000
           3       0.60      0.44      0.51      6000
           4       0.84      0.86      0.85      6000
           5       0.88      0.86      0.87      6000
           6       0.63      0.48      0.54      6000
           7       0.65      0.73      0.69      6000
           8       0.68      0.79      0.73      6000
           9       0.72      0.79      0.76      6000

**total dataset size** = 60000

**accuracy** = 0.71    
 
**macro avg**  
**precision** = 0.70      
**recall** = 0.71      
**f1-score** = 0.70

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
import numpy as np

import torch
import torch.nn as nn
from torch import fft
import torch
import numpy as np
# from spikingjelly.clock_driven.neuron import MultiStepLIFNode
# from spikingjelly.clock_driven import functional
from scipy.signal import cont2discrete

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def get_model_size(model):
    total_params = sum(p.numel() for p in model.parameters())
    param_size = 4  # Assuming float32 (4 bytes per parameter)
    total_size = total_params * param_size
    return total_size / (1024 ** 2)  # Convert bytes to MB

def count_parameters(model):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params, trainable_params

class DropoutNd(nn.Module):
    def __init__(self, p: float = 0.5, tie=True, transposed=True):
        """
        tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d)
        """
        super().__init__()
        if p < 0 or p >= 1:
            raise ValueError("dropout probability has to be in [0, 1), " "but got {}".format(p))
        self.p = p
        self.tie = tie
        self.transposed = transposed
        self.binomial = torch.distributions.binomial.Binomial(probs=1-self.p)

    def forward(self, X):
        """X: (batch, dim, lengths...)."""
        if self.training:
            if not self.transposed: X = rearrange(X, 'b ... d -> b d ...')
            # binomial = torch.distributions.binomial.Binomial(probs=1-self.p) # This is incredibly slow because of CPU -> GPU copying
            mask_shape = X.shape[:2] + (1,)*(X.ndim-2) if self.tie else X.shape
            # mask = self.binomial.sample(mask_shape)
            mask = torch.rand(*mask_shape, device=X.device) < 1.-self.p
            X = X * mask * (1.0/(1-self.p))
            if not self.transposed: X = rearrange(X, 'b d ... -> b ... d')
            return X
        return X

class S4DKernel(nn.Module):
    """Generate convolution kernel from diagonal SSM parameters."""
    def __init__(self, d_model, N=64, dt_min=0.001, dt_max=0.1, lr=None):
        super().__init__()
        # Generate dt
        # print(N, dt_max, dt_min)
        H = d_model
        log_dt = torch.rand(H) * ( math.log(dt_max) - math.log(dt_min) ) + math.log(dt_min)

        C = torch.randn(H, N // 2, dtype=torch.cfloat)
        self.C = nn.Parameter(torch.view_as_real(C))
        self.register("log_dt", log_dt, lr)

        log_A_real = torch.log(0.5 * torch.ones(H, N//2))
        A_imag = math.pi * repeat(torch.arange(N//2), 'n -> h n', h=H)
        self.register("log_A_real", log_A_real, lr)
        self.register("A_imag", A_imag, lr)

    def forward(self, L):
        """
        returns: (..., c, L) where c is number of channels (default 1)
        """

        # Materialize parameters
        dt = torch.exp(self.log_dt) # (H)
        C = torch.view_as_complex(self.C) # (H N)
        A = -torch.exp(self.log_A_real) + 1j * self.A_imag # (H N)

        # Vandermonde multiplication
        dtA = A * dt.unsqueeze(-1)  # (H N)
        K = dtA.unsqueeze(-1) * torch.arange(L, device=A.device) # (H N L)
        C = C * (torch.exp(dtA)-1.) / A
        K = 2 * torch.einsum('hn, hnl -> hl', C, torch.exp(K)).real

        return K

    def register(self, name, tensor, lr=None):
        """Register a tensor with a configurable learning rate and 0 weight decay"""

        if lr == 0.0:
            self.register_buffer(name, tensor)
        else:
            self.register_parameter(name, nn.Parameter(tensor))

            optim = {"weight_decay": 0.0}
            if lr is not None: optim["lr"] = lr
            setattr(getattr(self, name), "_optim", optim)

class S4D(nn.Module):
    def __init__(self, d_model, d_state=64, dropout=0.0, transposed=True):
        super().__init__()

        self.h = d_model
        self.n = d_state
        self.d_output = self.h
        self.transposed = transposed

        self.D = nn.Parameter(torch.randn(self.h))

        # SSM Kernel
        self.kernel = S4DKernel(self.h, N=self.n)

        # Pointwise
        self.activation = nn.GELU()
        # dropout_fn = nn.Dropout2d # NOTE: bugged in PyTorch 1.11
        dropout_fn = DropoutNd
        self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity()

        # position-wise output transform to mix features
        self.output_linear = nn.Sequential(
            nn.Conv1d(self.h, 2*self.h, kernel_size=1),
            nn.GLU(dim=-2),
        )

    def forward(self, u, **kwargs): # absorbs return_output and transformer src mask
        """ Input and output shape (B, H, L) """
        if not self.transposed: u = u.transpose(-1, -2)
        L = u.size(-1)

        # Compute SSM Kernel
        k = self.kernel(L=L) # (H L)

        # Convolution
        k_f = torch.fft.rfft(k, n=2*L) # (H L)
        u_f = torch.fft.rfft(u, n=2*L) # (B H L)
        y = torch.fft.irfft(u_f*k_f, n=2*L)[..., :L] # (B H L)

        # Compute D term in state space equation - essentially a skip connection
        y = y + u * self.D.unsqueeze(-1)

        y = self.dropout(self.activation(y))
        y = self.output_linear(y)
        if not self.transposed: y = y.transpose(-1, -2)
        return y, None # Return a dummy state to satisfy this repo's interface, but this can be modified

def get_act(act_type = 'spike', **act_params):
    '''
    act_type :- spike, gelu, relu, identity

    output :- class <act_type>
    '''
    act_type = act_type.lower()
    # if act_type == 'spike':
    #     return MultiStepLIFNode(**act_params, backend='cupy')
    #     # act_params['init_tau'] = act_params.pop('tau')
    #     # return MultiStepParametricLIFNode(**act_params, backend="cupy")
    if act_type == 'relu':
        return nn.ReLU()
    elif act_type == 'gelu':
        return nn.GELU()
    elif act_type == 'identity':
        return nn.Identity()

class S4(nn.Module):
    def __init__(self, dim, T, use_all_h=True):
        super().__init__()
        self.dim = dim
        self.use_all_h = use_all_h
        self.S4 = S4D(d_model= dim, dropout=0.1, transposed=False)
        self.proj_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1)
        self.proj_bn = nn.BatchNorm1d(dim)
        self.proj_lif = get_act('relu', tau=2.0, detach_reset=True)

    def forward(self, x):
        x = x.transpose(-1,-2).contiguous()
        h, h_n = self.S4(x) # B, N, C; B, C, 1

        x = h.transpose(-1,-2).contiguous()

        x = self.proj_conv(x)
        x = self.proj_lif(self.proj_bn(x).permute(2,1,0).contiguous()).permute(2,1,0).contiguous()
        return x

class LinearFFN(nn.Module):
    def __init__(self, in_features, pre_norm=False, hidden_features=None, out_features=None, drop=0., act_type='spike'):
        super().__init__()
        self.pre_norm = pre_norm
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1_linear  = nn.Linear(in_features, hidden_features)
        self.fc1_ln = nn.LayerNorm(hidden_features)
        self.fc1_lif = get_act(act_type if act_type == 'spike' else 'gelu', tau=2.0, detach_reset=True)

        self.fc2_linear = nn.Linear(hidden_features, out_features)
        self.fc2_ln = nn.LayerNorm(out_features)
        self.fc2_lif = get_act(act_type, tau=2.0, detach_reset=True)
 
        self.c_hidden = hidden_features
        self.c_output = out_features

    def forward(self, x):
        B,C,N = x.shape
        # 
        x = x.permute(0,2,1) # B, N, C
        # x = x.reshape(B*N, C)
        if self.pre_norm:
            x = self.fc1_ln(x)
            x = self.fc1_lif(x)
            x = self.fc1_linear(x)
            
            x = self.fc2_ln(x)
            x = self.fc2_lif(x)
            x = self.fc2_linear(x)

        else:
            x = self.fc1_linear(x)
            x = self.fc1_ln(x)
            x = self.fc1_lif(x)

            x = self.fc2_linear(x)
            x = self.fc2_ln(x)
            x = self.fc2_lif(x)

        # x = x.reshape(B, N, self.c_output)
        x = x.permute(0,2,1) # B, C, N
        return x
    
class Block(nn.Module):
    def __init__(self, dim, T, mlp_ratio=4., act_type='spike'):
        super().__init__()

        self.attn = S4(dim=dim, T=T)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = LinearFFN(in_features=dim, hidden_features=mlp_hidden_dim, act_type=act_type)

    def forward(self, x):
        x = x + self.attn(x)
        x = x + self.mlp(x)
        return x

class perm(nn.Module):
    def __init__(self, a, b, c) -> None:
        super().__init__()
        self.a = a
        self.b = b
        self.c = c

    def forward(self, x):
        return x.permute(self.a,self.b,self.c).contiguous()

def get_conv_block(T, dim, act_type, kernel_size=3, padding=1, groups=1):
    return [
        perm(0,2,1),
        nn.Conv1d(dim, dim, kernel_size=kernel_size, stride=1, padding=padding, groups=groups, bias=False),
        nn.BatchNorm1d(dim),
        perm(1,2,0),
        get_act(act_type, tau=2.0, detach_reset=True),
        perm(2,1,0)
]

class Conv1d4EB(nn.Module):
    def __init__(self, T=128, vw_dim=256, act_type='spike'):
        super().__init__()

        kernel_size = 3
        padding = 1
        groups = 1
        self.proj_conv = nn.ModuleList(
            [perm(0,2,1)]+\
            get_conv_block(T, vw_dim, act_type)+\
            get_conv_block(T, vw_dim, act_type, kernel_size=kernel_size, padding=padding, groups=groups)+\
            get_conv_block(T, vw_dim, act_type, kernel_size=kernel_size, padding=padding, groups=groups)+\
            get_conv_block(T, vw_dim, act_type, kernel_size=kernel_size, padding=padding, groups=groups)+\
            [perm(0,2,1)]
        )
        self.rpe_conv = nn.ModuleList(
            [perm(0,2,1)]+\
            get_conv_block(T, vw_dim, act_type, kernel_size=kernel_size, padding=padding, groups=groups)+\
            [perm(0,2,1)]
        )
        self.act_loss = 0.0
        
    def forward(self, x):

        for ele in self.proj_conv:
            x = ele(x)

        x_rpe = x.clone()
        for ele in self.rpe_conv:
            x_rpe = ele(x_rpe)

        x = x + x_rpe
        
        return x 

from transformers import BertModel , BertTokenizer

class S4_sq_Classifier(nn.Module):
    def __init__(self, input_size, num_layers, hidden_size, num_classes, act_type='relu', T=784, test_mode='all_seq',with_head_lif=False):
        super().__init__()
        self.with_head_lif = with_head_lif
        self.test_mode = test_mode

        bert_model = BertModel.from_pretrained('bert-base-uncased')
        self.embedding = nn.Embedding.from_pretrained(bert_model.embeddings.word_embeddings.weight,freeze=False)

        self.in_layer = nn.Linear(768, hidden_size)

        self.patch_embed = Conv1d4EB(T=T, vw_dim=hidden_size, act_type=act_type)

        self.block = nn.ModuleList([
            Block(dim=hidden_size, T=T, act_type=act_type)
            for j in range(num_layers)
        ])

        # classification head
        if self.with_head_lif:
            self.head_bn = nn.BatchNorm1d(hidden_size)
            self.head_lif = get_act(act_type, tau=2.0, detach_reset=True)

        self.head = nn.Linear(hidden_size, num_classes)
        self.loss = nn.CrossEntropyLoss()

    def forward_features(self, x):
        x = self.patch_embed(x)
        for blk in self.block:
            x = blk(x)
        return x

    def forward(self, x, labels=None, infer=False):
        self.act_loss = 0.0
        x = self.embedding(x)
        x = self.in_layer(x)
        x = x.permute(0, 2, 1).contiguous()
        x = self.forward_features(x)    # b, d, t -> b, d, t

        if self.with_head_lif:
            x = self.head_bn(x)         # b, d, t 
            x = self.head_lif(x)        # b, d, t

        x = x.permute(0, 2, 1).contiguous()
        x = torch.mean(x, 1)
        out = self.head(x)
        if infer:
            return out
        
        return self.loss(out, labels)
    

  from .autonotebook import tqdm as notebook_tqdm
2024-12-01 10:01:39.971502: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-12-01 10:01:39.985741: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1733047300.000331 1353697 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1733047300.004704 1353697 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-01 10:01:40.019463: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorF

In [2]:

def count_parameters(model):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params, trainable_params

# Example usage
model = S4_sq_Classifier(num_classes=10, input_size=512, num_layers=2, hidden_size=128, T=256).to(device)
total_params, trainable_params = count_parameters(model)
print(f"Total parameters: {total_params}")
print(f"Trainable parameters: {trainable_params}")


Total parameters: 24186506
Trainable parameters: 24186506


In [2]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup
from sklearn.metrics import classification_report
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

class YahooDataset(Dataset):
    def __init__(self, dataframe, tokenizer, max_length=256):
        self.tokenizer = tokenizer
        self.max_length = max_length
        
        # Combine text fields and clean up
        self.texts = (dataframe['question_title'].fillna('') + ' [SEP] ' + 
                     dataframe['question_content'].fillna('') + ' [SEP] ' + 
                     dataframe['best_answer'].fillna(''))
        
        # Convert to zero-based indexing
        self.labels = dataframe['class'] - 1

    def __getitem__(self, idx):
        text = str(self.texts.iloc[idx])
        label = self.labels.iloc[idx]
        
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }
    
    def __len__(self):
        return len(self.texts)

def train_epoch(model, dataloader, optimizer, scheduler, device):
    model.train()
    total_loss = 0
    progress_bar = tqdm(dataloader, desc='Training')
    
    for batch in progress_bar:
        optimizer.zero_grad()
        
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)
        
        # print(input_ids.shape)
        # print(labels.shape)
        # print(attention_mask.shape)
        loss = model(
            input_ids,
            labels=labels
        )
        # print(loss)
        # assert False,''
                
        total_loss += loss.item()
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        
        progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    return total_loss / len(dataloader)

def evaluate(model, dataloader, device):
    model.eval()
    true_labels = []
    predictions = []
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Evaluating'):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            
            outputs = model(
                input_ids,
                infer=True
            )
            
            preds = torch.argmax(outputs, dim=1)
            true_labels.extend(labels.cpu().numpy())
            predictions.extend(preds.cpu().numpy())
    
    return classification_report(true_labels, predictions, zero_division=0)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Load datasets
print('Loading datasets...')
train_df = pd.read_csv('train.csv', 
                        names=['class', 'question_title', 'question_content', 'best_answer'])

# Sample 10,000 examples per class for balanced training
samples_per_class = 10000
sampled_train_df = []

for class_idx in range(1, 11):  # 10 classes
    class_data = train_df[train_df['class'] == class_idx]
    sampled_class = class_data.sample(n=min(samples_per_class, len(class_data)), 
                                    random_state=42)
    sampled_train_df.append(sampled_class)

train_df = pd.concat(sampled_train_df, ignore_index=True)
print(f'Training with {len(train_df)} examples')

# Load test data
test_df = pd.read_csv('test.csv', 
                        names=['class', 'question_title', 'question_content', 'best_answer'])

# Load classes
with open('classes.txt', 'r') as f:
    class_names = [line.strip() for line in f.readlines()]

# Initialize tokenizer
print('Loading tokenizer...')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

model = S4_sq_Classifier(num_classes=10, input_size=512, num_layers=2, hidden_size=128, T=256).to(device)

# Create datasets
train_dataset = YahooDataset(train_df, tokenizer)
test_dataset = YahooDataset(test_df, tokenizer)

# Create dataloaders
batch_size = 32  # Can use larger batch size with smaller model
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    sampler=RandomSampler(train_dataset)
)
test_loader = DataLoader(test_dataset, batch_size=batch_size)


Using device: cuda
Loading datasets...
Training with 100000 examples
Loading tokenizer...


In [None]:

# Training settings
epochs = 30
optimizer = AdamW(model.parameters(), lr=5e-5)  # Slightly higher learning rate
total_steps = len(train_loader) * epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)

# Training loop
print('Starting training...')
best_accuracy = 0

for epoch in range(epochs):
    print(f'\nEpoch {epoch + 1}/{epochs}')
    
    train_loss = train_epoch(model, train_loader, optimizer, scheduler, device)
    print(f'Average training loss: {train_loss:.4f}')
    
    print('\nEvaluating...')
    report = evaluate(model, test_loader, device)
    print('\nClassification Report:')
    print(report)
    
    # Save model if it improves
    accuracy = float(report.split('\n')[-2].split()[-2])
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(model.state_dict(), 'best_yahoo_S4_2layer.pt')
        print(f'Saved best model with accuracy: {accuracy:.4f}')


Starting training...

Epoch 1/30


Training: 100%|██████████| 3125/3125 [03:23<00:00, 15.33it/s, loss=1.0156]


Average training loss: 1.2083

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [01:08<00:00, 27.20it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.59      0.55      0.57      6000
           1       0.78      0.63      0.70      6000
           2       0.75      0.79      0.77      6000
           3       0.49      0.55      0.52      6000
           4       0.82      0.87      0.84      6000
           5       0.87      0.86      0.87      6000
           6       0.58      0.48      0.53      6000
           7       0.70      0.69      0.69      6000
           8       0.70      0.77      0.73      6000
           9       0.72      0.79      0.75      6000

    accuracy                           0.70     60000
   macro avg       0.70      0.70      0.70     60000
weighted avg       0.70      0.70      0.70     60000

Saved best model with accuracy: 0.7000

Epoch 2/30


Training: 100%|██████████| 3125/3125 [03:21<00:00, 15.54it/s, loss=1.2821]


Average training loss: 0.8741

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [01:10<00:00, 26.49it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.54      0.63      0.58      6000
           1       0.74      0.71      0.72      6000
           2       0.77      0.77      0.77      6000
           3       0.60      0.44      0.51      6000
           4       0.84      0.86      0.85      6000
           5       0.88      0.86      0.87      6000
           6       0.63      0.48      0.54      6000
           7       0.65      0.73      0.69      6000
           8       0.68      0.79      0.73      6000
           9       0.72      0.79      0.76      6000

    accuracy                           0.71     60000
   macro avg       0.70      0.71      0.70     60000
weighted avg       0.70      0.71      0.70     60000


Epoch 3/30


Training: 100%|██████████| 3125/3125 [02:57<00:00, 17.65it/s, loss=0.6946]


Average training loss: 0.7403

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [01:10<00:00, 26.75it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.55      0.57      0.56      6000
           1       0.74      0.62      0.67      6000
           2       0.73      0.79      0.76      6000
           3       0.45      0.58      0.51      6000
           4       0.85      0.83      0.84      6000
           5       0.87      0.86      0.87      6000
           6       0.56      0.48      0.51      6000
           7       0.76      0.63      0.69      6000
           8       0.75      0.73      0.74      6000
           9       0.71      0.79      0.75      6000

    accuracy                           0.69     60000
   macro avg       0.70      0.69      0.69     60000
weighted avg       0.70      0.69      0.69     60000


Epoch 4/30


Training: 100%|██████████| 3125/3125 [03:14<00:00, 16.05it/s, loss=0.4865]


Average training loss: 0.6162

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [01:09<00:00, 26.88it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.58      0.55      0.57      6000
           1       0.68      0.72      0.70      6000
           2       0.72      0.79      0.75      6000
           3       0.61      0.39      0.48      6000
           4       0.80      0.87      0.83      6000
           5       0.88      0.84      0.86      6000
           6       0.54      0.48      0.51      6000
           7       0.64      0.72      0.68      6000
           8       0.68      0.78      0.73      6000
           9       0.73      0.75      0.74      6000

    accuracy                           0.69     60000
   macro avg       0.69      0.69      0.68     60000
weighted avg       0.69      0.69      0.68     60000


Epoch 5/30


Training: 100%|██████████| 3125/3125 [03:01<00:00, 17.22it/s, loss=0.6127]


Average training loss: 0.4867

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [01:08<00:00, 27.23it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.51      0.57      0.54      6000
           1       0.72      0.62      0.66      6000
           2       0.77      0.73      0.75      6000
           3       0.49      0.48      0.49      6000
           4       0.81      0.85      0.83      6000
           5       0.87      0.84      0.86      6000
           6       0.55      0.44      0.49      6000
           7       0.61      0.70      0.65      6000
           8       0.69      0.75      0.72      6000
           9       0.70      0.75      0.72      6000

    accuracy                           0.67     60000
   macro avg       0.67      0.67      0.67     60000
weighted avg       0.67      0.67      0.67     60000


Epoch 6/30


Training: 100%|██████████| 3125/3125 [03:00<00:00, 17.34it/s, loss=0.1810]


Average training loss: 0.3631

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [01:10<00:00, 26.55it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.50      0.55      0.53      6000
           1       0.67      0.65      0.66      6000
           2       0.70      0.78      0.74      6000
           3       0.50      0.44      0.47      6000
           4       0.79      0.85      0.82      6000
           5       0.82      0.86      0.84      6000
           6       0.54      0.43      0.48      6000
           7       0.67      0.64      0.65      6000
           8       0.70      0.71      0.70      6000
           9       0.70      0.73      0.71      6000

    accuracy                           0.66     60000
   macro avg       0.66      0.66      0.66     60000
weighted avg       0.66      0.66      0.66     60000


Epoch 7/30


Training:  91%|█████████ | 2843/3125 [02:50<00:16, 16.77it/s, loss=0.2998]