**BEST Statistics:**

**Classification Report:**

              precision    recall  f1-score   support

           0       0.56      0.60      0.58      6000
           1       0.75      0.64      0.69      6000
           2       0.75      0.77      0.76      6000
           3       0.49      0.57      0.52      6000
           4       0.87      0.81      0.84      6000
           5       0.90      0.84      0.87      6000
           6       0.55      0.49      0.52      6000
           7       0.68      0.70      0.69      6000
           8       0.67      0.79      0.72      6000
           9       0.78      0.70      0.74      6000
    
**total dataset size** = 60000

**accuracy** = 0.69    
 
**macro avg**  
**precision** = 0.70      
**recall** = 0.69      
**f1-score** = 0.69

In [None]:
import torch
import torch.nn as nn
from torch import fft
import torch
import numpy as np
# from spikingjelly.clock_driven.neuron import MultiStepLIFNode
# from spikingjelly.clock_driven import functional
from scipy.signal import cont2discrete

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def get_model_size(model):
    total_params = sum(p.numel() for p in model.parameters())
    param_size = 4  # Assuming float32 (4 bytes per parameter)
    total_size = total_params * param_size
    return total_size / (1024 ** 2)  # Convert bytes to MB

def count_parameters(model):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params, trainable_params

class LMUFFTCell(nn.Module):
    def __init__(self, input_size, hidden_size, memory_size, seq_len, theta):
        super(LMUFFTCell, self).__init__()

        self.hidden_size = hidden_size
        self.memory_size = memory_size
        self.seq_len = seq_len
        self.theta = theta

        self.W_u = nn.Linear(in_features = input_size, out_features = 1)
        self.f_u = nn.ReLU()
        self.W_h = nn.Linear(in_features = memory_size + input_size, out_features = hidden_size)
        self.f_h = nn.ReLU()

        A, B = self.stateSpaceMatrices()
        self.register_buffer("A", A) # [memory_size, memory_size]
        self.register_buffer("B", B) # [memory_size, 1]

        H, fft_H = self.impulse()
        self.register_buffer("H", H) # [memory_size, seq_len]
        self.register_buffer("fft_H", fft_H) # [memory_size, seq_len + 1]

    def stateSpaceMatrices(self):
        """ Returns the discretized state space matrices A and B """

        Q = np.arange(self.memory_size, dtype = np.float64).reshape(-1, 1)
        R = (2*Q + 1) / self.theta
        i, j = np.meshgrid(Q, Q, indexing = "ij")

        # Continuous
        A = R * np.where(i < j, -1, (-1.0)**(i - j + 1))
        B = R * ((-1.0)**Q)
        C = np.ones((1, self.memory_size))
        D = np.zeros((1,))

        # Convert to discrete
        A, B, C, D, dt = cont2discrete(
            system = (A, B, C, D), 
            dt = 1.0, 
            method = "zoh"
        )

        # To torch.tensor
        A = torch.from_numpy(A).float() # [memory_size, memory_size]
        B = torch.from_numpy(B).float() # [memory_size, 1]
        
        return A, B

    def impulse(self):
        """ Returns the matrices H and the 1D Fourier transform of H (Equations 23, 26 of the paper) """

        H = []
        A_i = torch.eye(self.memory_size).to(self.A.device) 
        for t in range(self.seq_len):
            H.append(A_i @ self.B)
            A_i = self.A @ A_i

        H = torch.cat(H, dim = -1) # [memory_size, seq_len]
        fft_H = fft.rfft(H, n = 2*self.seq_len, dim = -1) # [memory_size, seq_len + 1]

        return H, fft_H

    def forward(self, x):
        """
        Parameters:
            x (torch.tensor): 
                Input of size [batch_size, seq_len, input_size]
        """
        batch_size, seq_len, input_size = x.shape
        # print("batch_size, seq_len, input_size", batch_size, seq_len, input_size)

        # Equation 18 of the paper
        u = self.f_u(self.W_u(x)) # [batch_size, seq_len, 1]

        # Equation 26 of the paper
        fft_input = u.permute(0, 2, 1) # [batch_size, 1, seq_len]
        fft_u = fft.rfft(fft_input, n = 2*seq_len, dim = -1) # [batch_size, seq_len, seq_len+1]

        # Element-wise multiplication (uses broadcasting)
        # [batch_size, 1, seq_len+1] * [1, memory_size, seq_len+1]
        temp = fft_u * self.fft_H.unsqueeze(0) # [batch_size, memory_size, seq_len+1]

        m = fft.irfft(temp, n = 2*seq_len, dim = -1) # [batch_size, memory_size, seq_len+1]
        m = m[:, :, :seq_len] # [batch_size, memory_size, seq_len]
        m = m.permute(0, 2, 1) # [batch_size, seq_len, memory_size]

        # Equation 20 of the paper (W_m@m + W_x@x  W@[m;x])
        input_h = torch.cat((m, x), dim = -1) # [batch_size, seq_len, memory_size + input_size]
        h = self.f_h(self.W_h(input_h)) # [batch_size, seq_len, hidden_size]

        h_n = h[:, -1, :] # [batch_size*T, hidden_size]

        return h, h_n
    
    def forward_recurrent(self, x, m_last):
        u = self.f_u(self.W_u(x)) # [batch_size, seq_len, 1]
        # A: torch.Size([512, 512]), m_last: torch.Size([256, 512]), B: torch.Size([512, 1]), u: torch.Size([256, 1])
        m = m_last @ self.A.T + u @ self.B.T  # [batch_size, memory_size]
        input_h = torch.cat((m, x), dim = -1) # [batch_size, seq_len, memory_size + input_size]
        h = self.f_h(self.W_h(input_h)) # [batch_size, seq_len, hidden_size]

        return h, m

class LMU(nn.Module):
    def __init__(self, dim, T, use_all_h=True):
        super().__init__()
        self.dim = dim
        self.hidden_size = dim
        self.memory_size = dim
        self.use_all_h = use_all_h
        self.lmu = LMUFFTCell(input_size=dim, hidden_size=self.hidden_size, memory_size=self.memory_size, seq_len=T, theta=T)
        # self.lmu = LMUFFTCell(input_size=dim, hidden_size=self.hidden_size, memory_size=self.memory_size, seq_len=64, theta=64)

        self.proj_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1)
        self.proj_bn = nn.BatchNorm1d(dim)

    def forward(self, x):
        x = x.transpose(-1,-2).contiguous() # B, C, N -> B, N, C
        h, _ = self.lmu(x) # B, N, C; B, C
        
        x = h.transpose(-1,-2).contiguous() #if self.use_all_h else h_n.unsqueeze(-1) # h or h_n

        x = self.proj_conv(x)
        x = self.proj_bn(x)

        return x

class LinearFFN(nn.Module):
    def __init__(self, in_features, pre_norm=False, hidden_features=None, out_features=None, drop=0., act_type='spike'):
        super().__init__()
        self.pre_norm = pre_norm
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1_linear  = nn.Linear(in_features, hidden_features)
        self.fc1_ln = nn.LayerNorm(hidden_features)
        self.fc1_lif = get_act(act_type if act_type == 'spike' else 'gelu', tau=2.0, detach_reset=True)

        self.fc2_linear = nn.Linear(hidden_features, out_features)
        self.fc2_ln = nn.LayerNorm(out_features)
        self.fc2_lif = get_act(act_type, tau=2.0, detach_reset=True)
 
        self.c_hidden = hidden_features
        self.c_output = out_features

    def forward(self, x):
        B,C,N = x.shape
        # 
        x = x.permute(0,2,1) # B, N, C
        # x = x.reshape(B*N, C)
        if self.pre_norm:
            x = self.fc1_ln(x)
            x = self.fc1_lif(x)
            x = self.fc1_linear(x)
            
            x = self.fc2_ln(x)
            x = self.fc2_lif(x)
            x = self.fc2_linear(x)

        else:
            x = self.fc1_linear(x)
            x = self.fc1_ln(x)
            x = self.fc1_lif(x)

            x = self.fc2_linear(x)
            x = self.fc2_ln(x)
            x = self.fc2_lif(x)

        # x = x.reshape(B, N, self.c_output)
        x = x.permute(0,2,1) # B, C, N
        return x
    
class Block(nn.Module):
    def __init__(self, dim, T, mlp_ratio=4., act_type='spike'):
        super().__init__()

        self.attn = LMU(dim=dim, T=T)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = LinearFFN(in_features=dim, hidden_features=mlp_hidden_dim, act_type=act_type)

    def forward(self, x):
        x = x + self.attn(x)
        x = x + self.mlp(x)
        return x

class perm(nn.Module):
    def __init__(self, a, b, c) -> None:
        super().__init__()
        self.a = a
        self.b = b
        self.c = c

    def forward(self, x):
        return x.permute(self.a,self.b,self.c).contiguous()

def get_act(act_type = 'spike', **act_params):
    '''
    act_type :- spike, gelu, relu, identity

    output :- class <act_type>
    '''
    act_type = act_type.lower()
    # if act_type == 'spike':
    #     return MultiStepLIFNode(**act_params, backend='cupy')
    #     # act_params['init_tau'] = act_params.pop('tau')
    #     # return MultiStepParametricLIFNode(**act_params, backend="cupy")
    if act_type == 'relu':
        return nn.ReLU()
    elif act_type == 'gelu':
        return nn.GELU()
    elif act_type == 'identity':
        return nn.Identity()
    
def get_conv_block(T, dim, act_type, kernel_size=3, padding=1, groups=1):
    return [
        perm(0,2,1),
        nn.Conv1d(dim, dim, kernel_size=kernel_size, stride=1, padding=padding, groups=groups, bias=False),
        nn.BatchNorm1d(dim),
        perm(1,2,0),
        get_act(act_type, tau=2.0, detach_reset=True),
        perm(2,1,0)
]

class Conv1d4EB(nn.Module):
    def __init__(self, T=128, vw_dim=256, act_type='spike'):
        super().__init__()

        kernel_size = 3
        padding = 1
        groups = 1
        self.proj_conv = nn.ModuleList(
            [perm(0,2,1)]+\
            get_conv_block(T, vw_dim, act_type)+\
            get_conv_block(T, vw_dim, act_type, kernel_size=kernel_size, padding=padding, groups=groups)+\
            get_conv_block(T, vw_dim, act_type, kernel_size=kernel_size, padding=padding, groups=groups)+\
            get_conv_block(T, vw_dim, act_type, kernel_size=kernel_size, padding=padding, groups=groups)+\
            [perm(0,2,1)]
        )
        self.rpe_conv = nn.ModuleList(
            [perm(0,2,1)]+\
            get_conv_block(T, vw_dim, act_type, kernel_size=kernel_size, padding=padding, groups=groups)+\
            [perm(0,2,1)]
        )
        self.act_loss = 0.0
        
    def forward(self, x):

        for ele in self.proj_conv:
            x = ele(x)

        x_rpe = x.clone()
        for ele in self.rpe_conv:
            x_rpe = ele(x_rpe)

        x = x + x_rpe
        
        return x 

from transformers import BertModel , BertTokenizer

class LMUformer_sq_Classifier(nn.Module):
    def __init__(self, input_size, num_layers, hidden_size, num_classes, act_type='relu', T=784, test_mode='all_seq',with_head_lif=False):
        super().__init__()
        self.with_head_lif = with_head_lif
        self.test_mode = test_mode

        bert_model = BertModel.from_pretrained('bert-base-uncased')
        self.embedding = nn.Embedding.from_pretrained(bert_model.embeddings.word_embeddings.weight,freeze=False)

        self.in_layer = nn.Linear(768, hidden_size)

        self.patch_embed = Conv1d4EB(T=T, vw_dim=hidden_size, act_type=act_type)

        self.block = nn.ModuleList([
            Block(dim=hidden_size, T=T, act_type=act_type)
            for j in range(num_layers)
        ])

        # classification head
        if self.with_head_lif:
            self.head_bn = nn.BatchNorm1d(hidden_size)
            self.head_lif = get_act(act_type, tau=2.0, detach_reset=True)

        self.head = nn.Linear(hidden_size, num_classes)
        self.loss = nn.CrossEntropyLoss()

    def forward_features(self, x):
        x = self.patch_embed(x)
        for blk in self.block:
            x = blk(x)
        return x

    def forward(self, x, labels=None, infer=False):
        self.act_loss = 0.0
        x = self.embedding(x)
        x = self.in_layer(x)
        x = x.permute(0, 2, 1).contiguous()
        x = self.forward_features(x)    # b, d, t -> b, d, t

        if self.with_head_lif:
            x = self.head_bn(x)         # b, d, t 
            x = self.head_lif(x)        # b, d, t

        x = x.permute(0, 2, 1).contiguous()
        x = torch.mean(x, 1)
        out = self.head(x)
        if infer:
            return out
        
        return self.loss(out, labels)
    


  from .autonotebook import tqdm as notebook_tqdm
2024-12-01 09:58:05.649494: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-12-01 09:58:05.661475: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1733047085.675481 1358246 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1733047085.679699 1358246 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-01 09:58:05.694162: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorF

In [43]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup
from sklearn.metrics import classification_report
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

class YahooDataset(Dataset):
    def __init__(self, dataframe, tokenizer, max_length=256):
        self.tokenizer = tokenizer
        self.max_length = max_length
        
        # Combine text fields and clean up
        self.texts = (dataframe['question_title'].fillna('') + ' [SEP] ' + 
                     dataframe['question_content'].fillna('') + ' [SEP] ' + 
                     dataframe['best_answer'].fillna(''))
        
        # Convert to zero-based indexing
        self.labels = dataframe['class'] - 1

    def __getitem__(self, idx):
        text = str(self.texts.iloc[idx])
        label = self.labels.iloc[idx]
        
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }
    
    def __len__(self):
        return len(self.texts)

def train_epoch(model, dataloader, optimizer, scheduler, device):
    model.train()
    total_loss = 0
    progress_bar = tqdm(dataloader, desc='Training')
    
    for batch in progress_bar:
        optimizer.zero_grad()
        
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)
        
        # print(input_ids.shape)
        # print(labels.shape)
        # print(attention_mask.shape)
        loss = model(
            input_ids,
            labels=labels
        )
        # print(loss)
        # assert False,''
                
        total_loss += loss.item()
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        
        progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    return total_loss / len(dataloader)

def evaluate(model, dataloader, device):
    model.eval()
    true_labels = []
    predictions = []
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Evaluating'):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            
            outputs = model(
                input_ids,
                infer=True
            )
            
            preds = torch.argmax(outputs, dim=1)
            true_labels.extend(labels.cpu().numpy())
            predictions.extend(preds.cpu().numpy())
    
    return classification_report(true_labels, predictions, zero_division=0)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Load datasets
print('Loading datasets...')
train_df = pd.read_csv('train.csv', 
                        names=['class', 'question_title', 'question_content', 'best_answer'])

# Sample 10,000 examples per class for balanced training
samples_per_class = 10000
sampled_train_df = []

for class_idx in range(1, 11):  # 10 classes
    class_data = train_df[train_df['class'] == class_idx]
    sampled_class = class_data.sample(n=min(samples_per_class, len(class_data)), 
                                    random_state=42)
    sampled_train_df.append(sampled_class)

train_df = pd.concat(sampled_train_df, ignore_index=True)
print(f'Training with {len(train_df)} examples')

# Load test data
test_df = pd.read_csv('test.csv', 
                        names=['class', 'question_title', 'question_content', 'best_answer'])

# Load classes
with open('classes.txt', 'r') as f:
    class_names = [line.strip() for line in f.readlines()]

# Initialize tokenizer
print('Loading tokenizer...')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

model = LMUformer_sq_Classifier(num_classes=10, input_size=512, num_layers=2, hidden_size=1024, T=256).to(device)

total_params, trainable_params = count_parameters(model)
print(f"Total parameters: {total_params}")
print(f"Trainable parameters: {trainable_params}")

size_in_mb = get_model_size(model)
print(f"Model size: {size_in_mb:.2f} MB")

# Create datasets
train_dataset = YahooDataset(train_df, tokenizer)
test_dataset = YahooDataset(test_df, tokenizer)

# Create dataloaders
batch_size = 32  # Can use larger batch size with smaller model
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    sampler=RandomSampler(train_dataset)
)
test_loader = DataLoader(test_dataset, batch_size=batch_size)


Using device: cuda
Loading datasets...
Training with 100000 examples
Loading tokenizer...
Total parameters: 63087116
Trainable parameters: 63087116
Model size: 240.66 MB


In [44]:

# Training settings
epochs = 7
optimizer = AdamW(model.parameters(), lr=5e-5)  # Slightly higher learning rate
total_steps = len(train_loader) * epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)

# Training loop
print('Starting training...')
best_accuracy = 0

for epoch in range(epochs):
    print(f'\nEpoch {epoch + 1}/{epochs}')
    
    train_loss = train_epoch(model, train_loader, optimizer, scheduler, device)
    print(f'Average training loss: {train_loss:.4f}')
    
    print('\nEvaluating...')
    report = evaluate(model, test_loader, device)
    print('\nClassification Report:')
    print(report)
    
    # Save model if it improves
    accuracy = float(report.split('\n')[-2].split()[-2])
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(model.state_dict(), 'best_yahoo_LMUformer_2layer.pt')
        print(f'Saved best model with accuracy: {accuracy:.4f}')


Starting training...

Epoch 1/7


Training: 100%|██████████| 3125/3125 [13:51<00:00,  3.76it/s, loss=0.9979]


Average training loss: 1.1376

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [03:56<00:00,  7.94it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.62      0.53      0.57      6000
           1       0.66      0.77      0.71      6000
           2       0.75      0.76      0.76      6000
           3       0.56      0.47      0.51      6000
           4       0.84      0.84      0.84      6000
           5       0.88      0.84      0.86      6000
           6       0.56      0.52      0.54      6000
           7       0.58      0.78      0.66      6000
           8       0.75      0.71      0.73      6000
           9       0.77      0.74      0.76      6000

    accuracy                           0.70     60000
   macro avg       0.70      0.70      0.69     60000
weighted avg       0.70      0.70      0.69     60000

Saved best model with accuracy: 0.6900

Epoch 2/7


Training: 100%|██████████| 3125/3125 [13:52<00:00,  3.76it/s, loss=0.4604]


Average training loss: 0.8642

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [03:52<00:00,  8.08it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.59      0.54      0.57      6000
           1       0.75      0.68      0.72      6000
           2       0.72      0.82      0.77      6000
           3       0.55      0.51      0.53      6000
           4       0.85      0.85      0.85      6000
           5       0.86      0.87      0.87      6000
           6       0.66      0.46      0.54      6000
           7       0.65      0.74      0.69      6000
           8       0.65      0.82      0.73      6000
           9       0.77      0.77      0.77      6000

    accuracy                           0.71     60000
   macro avg       0.71      0.71      0.70     60000
weighted avg       0.71      0.71      0.70     60000

Saved best model with accuracy: 0.7000

Epoch 3/7


Training: 100%|██████████| 3125/3125 [12:02<00:00,  4.32it/s, loss=1.1622]


Average training loss: 0.7029

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [02:20<00:00, 13.30it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.52      0.63      0.57      6000
           1       0.68      0.74      0.71      6000
           2       0.69      0.82      0.75      6000
           3       0.62      0.41      0.49      6000
           4       0.87      0.83      0.85      6000
           5       0.82      0.88      0.85      6000
           6       0.65      0.45      0.53      6000
           7       0.67      0.71      0.69      6000
           8       0.72      0.75      0.73      6000
           9       0.72      0.75      0.74      6000

    accuracy                           0.70     60000
   macro avg       0.70      0.70      0.69     60000
weighted avg       0.70      0.70      0.69     60000


Epoch 4/7


Training: 100%|██████████| 3125/3125 [08:10<00:00,  6.37it/s, loss=0.5467]


Average training loss: 0.5359

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [02:22<00:00, 13.14it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.62      0.49      0.54      6000
           1       0.66      0.75      0.70      6000
           2       0.73      0.77      0.75      6000
           3       0.57      0.45      0.50      6000
           4       0.80      0.87      0.84      6000
           5       0.91      0.81      0.86      6000
           6       0.60      0.42      0.50      6000
           7       0.69      0.67      0.68      6000
           8       0.58      0.83      0.69      6000
           9       0.69      0.79      0.74      6000

    accuracy                           0.69     60000
   macro avg       0.68      0.69      0.68     60000
weighted avg       0.68      0.69      0.68     60000


Epoch 5/7


Training: 100%|██████████| 3125/3125 [08:12<00:00,  6.34it/s, loss=0.4225]


Average training loss: 0.3692

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [02:21<00:00, 13.23it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.55      0.54      0.54      6000
           1       0.66      0.70      0.68      6000
           2       0.72      0.77      0.74      6000
           3       0.48      0.49      0.49      6000
           4       0.82      0.85      0.83      6000
           5       0.85      0.85      0.85      6000
           6       0.52      0.47      0.49      6000
           7       0.70      0.65      0.67      6000
           8       0.70      0.71      0.71      6000
           9       0.73      0.69      0.71      6000

    accuracy                           0.67     60000
   macro avg       0.67      0.67      0.67     60000
weighted avg       0.67      0.67      0.67     60000


Epoch 6/7


Training: 100%|██████████| 3125/3125 [08:16<00:00,  6.29it/s, loss=0.1276]


Average training loss: 0.2135

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [02:21<00:00, 13.21it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.56      0.51      0.53      6000
           1       0.71      0.63      0.67      6000
           2       0.72      0.76      0.74      6000
           3       0.47      0.47      0.47      6000
           4       0.82      0.84      0.83      6000
           5       0.87      0.83      0.85      6000
           6       0.43      0.53      0.47      6000
           7       0.69      0.65      0.67      6000
           8       0.67      0.72      0.69      6000
           9       0.75      0.68      0.71      6000

    accuracy                           0.66     60000
   macro avg       0.67      0.66      0.66     60000
weighted avg       0.67      0.66      0.66     60000


Epoch 7/7


Training: 100%|██████████| 3125/3125 [08:08<00:00,  6.40it/s, loss=0.1595]


Average training loss: 0.1096

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [02:19<00:00, 13.43it/s]


Classification Report:
              precision    recall  f1-score   support

           0       0.54      0.51      0.53      6000
           1       0.66      0.68      0.67      6000
           2       0.74      0.74      0.74      6000
           3       0.49      0.46      0.47      6000
           4       0.82      0.84      0.83      6000
           5       0.86      0.84      0.85      6000
           6       0.46      0.49      0.47      6000
           7       0.66      0.67      0.67      6000
           8       0.71      0.69      0.70      6000
           9       0.70      0.73      0.71      6000

    accuracy                           0.66     60000
   macro avg       0.66      0.66      0.66     60000
weighted avg       0.66      0.66      0.66     60000






In [None]:

# Training settings
epochs = 30
optimizer = AdamW(model.parameters(), lr=5e-5)  # Slightly higher learning rate
total_steps = len(train_loader) * epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)

# Training loop
print('Starting training...')
best_accuracy = 0

for epoch in range(epochs):
    print(f'\nEpoch {epoch + 1}/{epochs}')
    
    train_loss = train_epoch(model, train_loader, optimizer, scheduler, device)
    print(f'Average training loss: {train_loss:.4f}')
    
    print('\nEvaluating...')
    report = evaluate(model, test_loader, device)
    print('\nClassification Report:')
    print(report)
    
    # Save model if it improves
    accuracy = float(report.split('\n')[-2].split()[-2])
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(model.state_dict(), 'best_yahoo_LMUformer_2layer.pt')
        print(f'Saved best model with accuracy: {accuracy:.4f}')


Starting training...

Epoch 1/30


Training: 100%|██████████| 3125/3125 [02:57<00:00, 17.56it/s, loss=1.2577]


Average training loss: 1.2287

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [01:08<00:00, 27.19it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.64      0.50      0.56      6000
           1       0.65      0.77      0.70      6000
           2       0.78      0.73      0.75      6000
           3       0.55      0.47      0.51      6000
           4       0.80      0.84      0.82      6000
           5       0.79      0.89      0.84      6000
           6       0.53      0.49      0.51      6000
           7       0.64      0.71      0.67      6000
           8       0.69      0.75      0.72      6000
           9       0.76      0.73      0.75      6000

    accuracy                           0.69     60000
   macro avg       0.68      0.69      0.68     60000
weighted avg       0.68      0.69      0.68     60000

Saved best model with accuracy: 0.6800

Epoch 2/30


Training: 100%|██████████| 3125/3125 [03:01<00:00, 17.18it/s, loss=1.1389]


Average training loss: 0.9080

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [01:09<00:00, 26.87it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.56      0.60      0.58      6000
           1       0.75      0.64      0.69      6000
           2       0.75      0.77      0.76      6000
           3       0.49      0.57      0.52      6000
           4       0.87      0.81      0.84      6000
           5       0.90      0.84      0.87      6000
           6       0.55      0.49      0.52      6000
           7       0.68      0.70      0.69      6000
           8       0.67      0.79      0.72      6000
           9       0.78      0.70      0.74      6000

    accuracy                           0.69     60000
   macro avg       0.70      0.69      0.69     60000
weighted avg       0.70      0.69      0.69     60000

Saved best model with accuracy: 0.6900

Epoch 3/30


Training: 100%|██████████| 3125/3125 [02:54<00:00, 17.88it/s, loss=0.9929]


Average training loss: 0.7628

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [01:10<00:00, 26.58it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.60      0.52      0.56      6000
           1       0.74      0.66      0.70      6000
           2       0.72      0.79      0.75      6000
           3       0.57      0.46      0.51      6000
           4       0.79      0.88      0.83      6000
           5       0.80      0.89      0.84      6000
           6       0.57      0.47      0.51      6000
           7       0.66      0.69      0.68      6000
           8       0.67      0.78      0.72      6000
           9       0.71      0.77      0.74      6000

    accuracy                           0.69     60000
   macro avg       0.68      0.69      0.68     60000
weighted avg       0.68      0.69      0.68     60000


Epoch 4/30


Training: 100%|██████████| 3125/3125 [02:57<00:00, 17.59it/s, loss=0.7283]


Average training loss: 0.6200

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [01:09<00:00, 26.79it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.54      0.53      0.53      6000
           1       0.66      0.72      0.69      6000
           2       0.75      0.72      0.73      6000
           3       0.50      0.49      0.50      6000
           4       0.82      0.85      0.83      6000
           5       0.90      0.82      0.86      6000
           6       0.51      0.48      0.50      6000
           7       0.64      0.70      0.67      6000
           8       0.71      0.72      0.72      6000
           9       0.73      0.71      0.72      6000

    accuracy                           0.67     60000
   macro avg       0.68      0.67      0.67     60000
weighted avg       0.68      0.67      0.67     60000


Epoch 5/30


Training: 100%|██████████| 3125/3125 [02:59<00:00, 17.38it/s, loss=0.3395]


Average training loss: 0.4820

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [01:10<00:00, 26.48it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.51      0.56      0.53      6000
           1       0.71      0.66      0.68      6000
           2       0.67      0.79      0.72      6000
           3       0.56      0.40      0.47      6000
           4       0.81      0.83      0.82      6000
           5       0.85      0.84      0.85      6000
           6       0.52      0.45      0.48      6000
           7       0.64      0.67      0.66      6000
           8       0.66      0.75      0.70      6000
           9       0.72      0.73      0.72      6000

    accuracy                           0.67     60000
   macro avg       0.66      0.67      0.66     60000
weighted avg       0.66      0.67      0.66     60000


Epoch 6/30


Training: 100%|██████████| 3125/3125 [03:04<00:00, 16.94it/s, loss=0.1881]


Average training loss: 0.3551

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [01:10<00:00, 26.52it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.48      0.57      0.52      6000
           1       0.70      0.62      0.66      6000
           2       0.73      0.72      0.73      6000
           3       0.47      0.46      0.46      6000
           4       0.82      0.81      0.82      6000
           5       0.89      0.80      0.84      6000
           6       0.44      0.50      0.47      6000
           7       0.65      0.65      0.65      6000
           8       0.68      0.72      0.70      6000
           9       0.74      0.68      0.71      6000

    accuracy                           0.65     60000
   macro avg       0.66      0.65      0.66     60000
weighted avg       0.66      0.65      0.66     60000


Epoch 7/30


Training: 100%|██████████| 3125/3125 [03:06<00:00, 16.77it/s, loss=0.5009]


Average training loss: 0.2541

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [01:09<00:00, 27.00it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.59      0.41      0.49      6000
           1       0.66      0.60      0.63      6000
           2       0.76      0.66      0.71      6000
           3       0.41      0.53      0.46      6000
           4       0.78      0.81      0.80      6000
           5       0.84      0.83      0.84      6000
           6       0.41      0.48      0.45      6000
           7       0.61      0.67      0.63      6000
           8       0.70      0.68      0.69      6000
           9       0.71      0.69      0.70      6000

    accuracy                           0.64     60000
   macro avg       0.65      0.64      0.64     60000
weighted avg       0.65      0.64      0.64     60000


Epoch 8/30


Training: 100%|██████████| 3125/3125 [02:59<00:00, 17.36it/s, loss=0.3773]


Average training loss: 0.1803

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [01:10<00:00, 26.47it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.53      0.49      0.51      6000
           1       0.66      0.62      0.64      6000
           2       0.74      0.68      0.71      6000
           3       0.43      0.50      0.46      6000
           4       0.78      0.81      0.80      6000
           5       0.86      0.81      0.84      6000
           6       0.45      0.45      0.45      6000
           7       0.62      0.66      0.64      6000
           8       0.68      0.70      0.69      6000
           9       0.72      0.67      0.70      6000

    accuracy                           0.64     60000
   macro avg       0.65      0.64      0.64     60000
weighted avg       0.65      0.64      0.64     60000


Epoch 9/30


Training: 100%|██████████| 3125/3125 [02:52<00:00, 18.13it/s, loss=0.1792]


Average training loss: 0.1271

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [01:10<00:00, 26.48it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.52      0.48      0.50      6000
           1       0.66      0.64      0.65      6000
           2       0.69      0.73      0.71      6000
           3       0.49      0.42      0.45      6000
           4       0.73      0.85      0.78      6000
           5       0.86      0.80      0.83      6000
           6       0.45      0.41      0.43      6000
           7       0.61      0.64      0.63      6000
           8       0.66      0.71      0.68      6000
           9       0.68      0.71      0.69      6000

    accuracy                           0.64     60000
   macro avg       0.63      0.64      0.64     60000
weighted avg       0.63      0.64      0.64     60000


Epoch 10/30


Training: 100%|██████████| 3125/3125 [02:59<00:00, 17.43it/s, loss=0.1849]


Average training loss: 0.0914

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [01:09<00:00, 27.08it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.44      0.57      0.50      6000
           1       0.71      0.57      0.63      6000
           2       0.68      0.73      0.70      6000
           3       0.47      0.41      0.44      6000
           4       0.84      0.76      0.80      6000
           5       0.88      0.78      0.83      6000
           6       0.39      0.49      0.43      6000
           7       0.65      0.60      0.62      6000
           8       0.64      0.72      0.68      6000
           9       0.73      0.64      0.68      6000

    accuracy                           0.63     60000
   macro avg       0.64      0.63      0.63     60000
weighted avg       0.64      0.63      0.63     60000


Epoch 11/30


Training: 100%|██████████| 3125/3125 [03:01<00:00, 17.18it/s, loss=0.0340]


Average training loss: 0.0661

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [01:11<00:00, 26.35it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.48      0.53      0.51      6000
           1       0.70      0.56      0.62      6000
           2       0.74      0.69      0.71      6000
           3       0.47      0.40      0.43      6000
           4       0.80      0.79      0.79      6000
           5       0.90      0.76      0.83      6000
           6       0.41      0.48      0.44      6000
           7       0.63      0.62      0.63      6000
           8       0.64      0.72      0.68      6000
           9       0.63      0.75      0.68      6000

    accuracy                           0.63     60000
   macro avg       0.64      0.63      0.63     60000
weighted avg       0.64      0.63      0.63     60000


Epoch 12/30


Training: 100%|██████████| 3125/3125 [03:02<00:00, 17.16it/s, loss=0.0276]


Average training loss: 0.0505

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [01:11<00:00, 26.30it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.47      0.53      0.50      6000
           1       0.70      0.57      0.63      6000
           2       0.74      0.70      0.71      6000
           3       0.46      0.43      0.45      6000
           4       0.82      0.78      0.80      6000
           5       0.91      0.72      0.81      6000
           6       0.37      0.50      0.43      6000
           7       0.64      0.61      0.62      6000
           8       0.68      0.66      0.67      6000
           9       0.63      0.74      0.68      6000

    accuracy                           0.62     60000
   macro avg       0.64      0.62      0.63     60000
weighted avg       0.64      0.62      0.63     60000


Epoch 13/30


Training: 100%|██████████| 3125/3125 [02:53<00:00, 17.98it/s, loss=0.0154]


Average training loss: 0.0373

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [01:09<00:00, 27.03it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.52      0.47      0.49      6000
           1       0.71      0.54      0.62      6000
           2       0.72      0.70      0.71      6000
           3       0.43      0.46      0.44      6000
           4       0.77      0.79      0.78      6000
           5       0.85      0.81      0.82      6000
           6       0.38      0.50      0.43      6000
           7       0.63      0.61      0.62      6000
           8       0.66      0.69      0.67      6000
           9       0.70      0.68      0.69      6000

    accuracy                           0.63     60000
   macro avg       0.64      0.63      0.63     60000
weighted avg       0.64      0.63      0.63     60000


Epoch 14/30


Training: 100%|██████████| 3125/3125 [03:03<00:00, 17.00it/s, loss=0.0068]


Average training loss: 0.0290

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [01:10<00:00, 26.72it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.51      0.49      0.50      6000
           1       0.65      0.64      0.64      6000
           2       0.68      0.72      0.70      6000
           3       0.45      0.43      0.44      6000
           4       0.78      0.80      0.79      6000
           5       0.85      0.80      0.83      6000
           6       0.43      0.45      0.44      6000
           7       0.64      0.62      0.63      6000
           8       0.69      0.65      0.67      6000
           9       0.64      0.73      0.68      6000

    accuracy                           0.63     60000
   macro avg       0.63      0.63      0.63     60000
weighted avg       0.63      0.63      0.63     60000


Epoch 15/30


Training: 100%|██████████| 3125/3125 [03:03<00:00, 17.07it/s, loss=0.0060]


Average training loss: 0.0214

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [01:09<00:00, 26.92it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.52      0.46      0.49      6000
           1       0.68      0.61      0.64      6000
           2       0.69      0.72      0.71      6000
           3       0.43      0.46      0.44      6000
           4       0.83      0.75      0.79      6000
           5       0.89      0.77      0.83      6000
           6       0.38      0.49      0.43      6000
           7       0.66      0.60      0.63      6000
           8       0.64      0.70      0.67      6000
           9       0.66      0.72      0.69      6000

    accuracy                           0.63     60000
   macro avg       0.64      0.63      0.63     60000
weighted avg       0.64      0.63      0.63     60000


Epoch 16/30


Training: 100%|██████████| 3125/3125 [03:03<00:00, 17.04it/s, loss=0.0019]


Average training loss: 0.0163

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [01:09<00:00, 27.15it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.51      0.48      0.50      6000
           1       0.65      0.63      0.64      6000
           2       0.72      0.70      0.71      6000
           3       0.42      0.48      0.45      6000
           4       0.71      0.84      0.77      6000
           5       0.86      0.81      0.83      6000
           6       0.46      0.39      0.43      6000
           7       0.60      0.66      0.63      6000
           8       0.69      0.65      0.67      6000
           9       0.70      0.68      0.69      6000

    accuracy                           0.63     60000
   macro avg       0.63      0.63      0.63     60000
weighted avg       0.63      0.63      0.63     60000


Epoch 17/30


Training: 100%|██████████| 3125/3125 [02:59<00:00, 17.41it/s, loss=0.0015]


Average training loss: 0.0114

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [01:11<00:00, 26.40it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.47      0.52      0.50      6000
           1       0.67      0.62      0.64      6000
           2       0.69      0.71      0.70      6000
           3       0.50      0.36      0.42      6000
           4       0.79      0.79      0.79      6000
           5       0.83      0.81      0.82      6000
           6       0.41      0.47      0.44      6000
           7       0.58      0.66      0.62      6000
           8       0.68      0.65      0.67      6000
           9       0.69      0.67      0.68      6000

    accuracy                           0.63     60000
   macro avg       0.63      0.63      0.63     60000
weighted avg       0.63      0.63      0.63     60000


Epoch 18/30


Training: 100%|██████████| 3125/3125 [03:03<00:00, 17.04it/s, loss=0.0003]


Average training loss: 0.0089

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [01:10<00:00, 26.51it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.50      0.48      0.49      6000
           1       0.69      0.60      0.64      6000
           2       0.71      0.70      0.71      6000
           3       0.47      0.43      0.45      6000
           4       0.74      0.83      0.78      6000
           5       0.89      0.77      0.82      6000
           6       0.42      0.44      0.43      6000
           7       0.61      0.62      0.62      6000
           8       0.63      0.72      0.67      6000
           9       0.67      0.71      0.69      6000

    accuracy                           0.63     60000
   macro avg       0.63      0.63      0.63     60000
weighted avg       0.63      0.63      0.63     60000


Epoch 19/30


Training: 100%|██████████| 3125/3125 [03:07<00:00, 16.65it/s, loss=0.0024]


Average training loss: 0.0068

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [01:10<00:00, 26.42it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.49      0.49      0.49      6000
           1       0.65      0.63      0.64      6000
           2       0.70      0.70      0.70      6000
           3       0.42      0.47      0.44      6000
           4       0.81      0.77      0.79      6000
           5       0.85      0.80      0.83      6000
           6       0.41      0.46      0.43      6000
           7       0.64      0.60      0.62      6000
           8       0.68      0.66      0.67      6000
           9       0.69      0.69      0.69      6000

    accuracy                           0.63     60000
   macro avg       0.64      0.63      0.63     60000
weighted avg       0.64      0.63      0.63     60000


Epoch 20/30


Training: 100%|██████████| 3125/3125 [03:07<00:00, 16.68it/s, loss=0.0001]


Average training loss: 0.0051

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [01:10<00:00, 26.41it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.50      0.48      0.49      6000
           1       0.65      0.64      0.65      6000
           2       0.64      0.74      0.69      6000
           3       0.49      0.39      0.43      6000
           4       0.77      0.80      0.79      6000
           5       0.85      0.80      0.82      6000
           6       0.45      0.40      0.42      6000
           7       0.63      0.62      0.63      6000
           8       0.62      0.72      0.67      6000
           9       0.65      0.71      0.68      6000

    accuracy                           0.63     60000
   macro avg       0.63      0.63      0.63     60000
weighted avg       0.63      0.63      0.63     60000


Epoch 21/30


Training: 100%|██████████| 3125/3125 [03:08<00:00, 16.61it/s, loss=0.0003]


Average training loss: 0.0038

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [01:09<00:00, 26.85it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.51      0.46      0.48      6000
           1       0.69      0.61      0.65      6000
           2       0.67      0.73      0.70      6000
           3       0.50      0.38      0.43      6000
           4       0.75      0.81      0.78      6000
           5       0.84      0.81      0.83      6000
           6       0.45      0.43      0.44      6000
           7       0.59      0.66      0.62      6000
           8       0.59      0.74      0.66      6000
           9       0.68      0.69      0.69      6000

    accuracy                           0.63     60000
   macro avg       0.63      0.63      0.63     60000
weighted avg       0.63      0.63      0.63     60000


Epoch 22/30


Training: 100%|██████████| 3125/3125 [02:53<00:00, 18.04it/s, loss=0.0001]


Average training loss: 0.0025

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [01:10<00:00, 26.50it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.49      0.49      0.49      6000
           1       0.68      0.59      0.63      6000
           2       0.71      0.70      0.71      6000
           3       0.43      0.46      0.44      6000
           4       0.78      0.79      0.78      6000
           5       0.87      0.78      0.82      6000
           6       0.39      0.49      0.43      6000
           7       0.66      0.58      0.62      6000
           8       0.68      0.67      0.68      6000
           9       0.67      0.70      0.68      6000

    accuracy                           0.63     60000
   macro avg       0.64      0.63      0.63     60000
weighted avg       0.64      0.63      0.63     60000


Epoch 23/30


Training: 100%|██████████| 3125/3125 [02:55<00:00, 17.80it/s, loss=0.0000]


Average training loss: 0.0017

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [01:09<00:00, 26.85it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.54      0.45      0.49      6000
           1       0.67      0.60      0.63      6000
           2       0.73      0.68      0.70      6000
           3       0.45      0.42      0.44      6000
           4       0.70      0.84      0.76      6000
           5       0.84      0.80      0.82      6000
           6       0.36      0.51      0.42      6000
           7       0.60      0.63      0.61      6000
           8       0.74      0.59      0.66      6000
           9       0.69      0.66      0.68      6000

    accuracy                           0.62     60000
   macro avg       0.63      0.62      0.62     60000
weighted avg       0.63      0.62      0.62     60000


Epoch 24/30


Training: 100%|██████████| 3125/3125 [02:56<00:00, 17.68it/s, loss=0.0003]


Average training loss: 0.0014

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [01:10<00:00, 26.60it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.54      0.45      0.49      6000
           1       0.68      0.61      0.64      6000
           2       0.72      0.68      0.70      6000
           3       0.47      0.43      0.45      6000
           4       0.75      0.82      0.78      6000
           5       0.84      0.82      0.83      6000
           6       0.40      0.48      0.44      6000
           7       0.60      0.65      0.62      6000
           8       0.67      0.68      0.67      6000
           9       0.68      0.70      0.69      6000

    accuracy                           0.63     60000
   macro avg       0.63      0.63      0.63     60000
weighted avg       0.63      0.63      0.63     60000


Epoch 25/30


Training: 100%|██████████| 3125/3125 [02:58<00:00, 17.53it/s, loss=0.0004]


Average training loss: 0.0010

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [01:23<00:00, 22.32it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.48      0.50      0.49      6000
           1       0.68      0.60      0.64      6000
           2       0.66      0.73      0.69      6000
           3       0.46      0.42      0.44      6000
           4       0.78      0.80      0.79      6000
           5       0.85      0.80      0.83      6000
           6       0.43      0.45      0.44      6000
           7       0.63      0.62      0.62      6000
           8       0.64      0.71      0.67      6000
           9       0.70      0.66      0.68      6000

    accuracy                           0.63     60000
   macro avg       0.63      0.63      0.63     60000
weighted avg       0.63      0.63      0.63     60000


Epoch 26/30


Training: 100%|██████████| 3125/3125 [03:30<00:00, 14.84it/s, loss=0.0003]


Average training loss: 0.0006

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [01:19<00:00, 23.45it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.51      0.47      0.49      6000
           1       0.68      0.61      0.64      6000
           2       0.69      0.71      0.70      6000
           3       0.44      0.45      0.45      6000
           4       0.79      0.78      0.79      6000
           5       0.85      0.81      0.83      6000
           6       0.41      0.47      0.44      6000
           7       0.66      0.59      0.62      6000
           8       0.64      0.70      0.67      6000
           9       0.68      0.70      0.69      6000

    accuracy                           0.63     60000
   macro avg       0.63      0.63      0.63     60000
weighted avg       0.63      0.63      0.63     60000


Epoch 27/30


Training: 100%|██████████| 3125/3125 [03:58<00:00, 13.08it/s, loss=0.0001]


Average training loss: 0.0004

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [01:28<00:00, 21.13it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.49      0.49      0.49      6000
           1       0.65      0.63      0.64      6000
           2       0.71      0.68      0.70      6000
           3       0.45      0.43      0.44      6000
           4       0.75      0.81      0.78      6000
           5       0.86      0.79      0.83      6000
           6       0.40      0.48      0.44      6000
           7       0.65      0.60      0.62      6000
           8       0.68      0.67      0.67      6000
           9       0.68      0.69      0.69      6000

    accuracy                           0.63     60000
   macro avg       0.63      0.63      0.63     60000
weighted avg       0.63      0.63      0.63     60000


Epoch 28/30


Training: 100%|██████████| 3125/3125 [04:10<00:00, 12.45it/s, loss=0.0000]


Average training loss: 0.0002

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [01:29<00:00, 20.85it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.48      0.51      0.49      6000
           1       0.67      0.61      0.64      6000
           2       0.70      0.70      0.70      6000
           3       0.45      0.43      0.44      6000
           4       0.78      0.79      0.79      6000
           5       0.83      0.81      0.82      6000
           6       0.43      0.44      0.44      6000
           7       0.65      0.61      0.63      6000
           8       0.63      0.71      0.67      6000
           9       0.69      0.68      0.69      6000

    accuracy                           0.63     60000
   macro avg       0.63      0.63      0.63     60000
weighted avg       0.63      0.63      0.63     60000


Epoch 29/30


Training: 100%|██████████| 3125/3125 [04:01<00:00, 12.93it/s, loss=0.0001]


Average training loss: 0.0002

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [01:26<00:00, 21.60it/s]



Classification Report:
              precision    recall  f1-score   support

           0       0.49      0.50      0.49      6000
           1       0.66      0.62      0.64      6000
           2       0.69      0.70      0.70      6000
           3       0.45      0.45      0.45      6000
           4       0.78      0.79      0.79      6000
           5       0.86      0.79      0.82      6000
           6       0.42      0.43      0.43      6000
           7       0.63      0.62      0.62      6000
           8       0.67      0.69      0.68      6000
           9       0.67      0.71      0.69      6000

    accuracy                           0.63     60000
   macro avg       0.63      0.63      0.63     60000
weighted avg       0.63      0.63      0.63     60000


Epoch 30/30


Training: 100%|██████████| 3125/3125 [03:50<00:00, 13.56it/s, loss=0.0000]


Average training loss: 0.0001

Evaluating...


Evaluating: 100%|██████████| 1875/1875 [01:32<00:00, 20.35it/s]


Classification Report:
              precision    recall  f1-score   support

           0       0.48      0.51      0.49      6000
           1       0.67      0.61      0.64      6000
           2       0.68      0.71      0.70      6000
           3       0.45      0.44      0.45      6000
           4       0.77      0.80      0.79      6000
           5       0.86      0.79      0.83      6000
           6       0.41      0.46      0.43      6000
           7       0.63      0.62      0.62      6000
           8       0.68      0.67      0.67      6000
           9       0.69      0.68      0.68      6000

    accuracy                           0.63     60000
   macro avg       0.63      0.63      0.63     60000
weighted avg       0.63      0.63      0.63     60000




