## Importing of Libraries


In [2]:
import time
from copy import deepcopy

import numpy as np
import torch
from dlordinal.datasets import FGNet
from dlordinal.models import OBDECOCModel
from dlordinal.losses import OrdinalECOCDistanceLoss
from sklearn.metrics import (accuracy_score, cohen_kappa_score,
                             confusion_matrix, mean_absolute_error)
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.utils import class_weight
from torch import cuda
from torch.optim import Adam
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, ToTensor
from torchvision.models import resnet18
from tqdm import tqdm

## Load and preprocess of FGNet dataset

First, we present the configuration parameters for the experimentation and the number of workers for the *DataLoader*, which defines the number of subprocesses to use for data loading. In this specific case, it refers to the images.

In [3]:
optimiser_params = {
    'lr': 1e-3,
    'bs': 200,
    'epochs': 5,
    's': 2,
    'c': 0.2,
    'beta': 0.5
}

workers = 3

Now we use the *FGNet* method to download and preprocess the images. Once that is done with the training data, we create a validation partition comprising 15% of the data using the *StratifiedShuffleSplit* method. Finally, with all the partitions, we load the images using a method called *DataLoader*.

In [7]:
fgnet = FGNet(root="./datasets/fgnet", download=True, process_data=True)

complete_train_data = ImageFolder(
    root="./datasets/fgnet/FGNET/train", transform=Compose([ToTensor()])
)
test_data = ImageFolder(
    root="./datasets/fgnet/FGNET/test", transform=Compose([ToTensor()])
)

num_classes = len(complete_train_data.classes)
classes = complete_train_data.classes
targets = complete_train_data.targets

# Create a validation split
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.15, random_state=0)
sss_splits = list(
    sss.split(X=np.zeros(len(complete_train_data)), y=complete_train_data.targets)
)
train_idx, val_idx = sss_splits[0]

# Create subsets for training and validation
train_data = Subset(complete_train_data, train_idx)
val_data = Subset(complete_train_data, val_idx)

# Get CUDA device
device = "cuda" if cuda.is_available() else "cpu"
print(f"Using {device} device")

# Create dataloaders
train_dataloader = DataLoader(
    train_data, batch_size=optimiser_params["bs"], shuffle=True, num_workers=workers
)
val_dataloader = DataLoader(
    val_data, batch_size=optimiser_params["bs"], shuffle=True, num_workers=workers
)
test_dataloader = DataLoader(
    test_data, batch_size=optimiser_params["bs"], shuffle=False, num_workers=workers
)

# Get image shape
img_shape = None
for X, _ in train_dataloader:
    img_shape = list(X.shape[1:])
    break
print(f"Detected image shape: {img_shape}")

# Define class weights for imbalanced datasets
class_weights = torch.from_numpy(class_weight.compute_class_weight(
    "balanced", classes=[int(c) for c in classes], y=targets
)).float().to(device)
print(f"{class_weights=}")

Files already downloaded and verified
Files already processed and verified
Files already split and verified
Using cpu device
Detected image shape: [3, 128, 128]
class_weights=tensor([1.5523, 0.5256, 1.1033, 0.8241, 1.0680, 2.5189])


## Model and optimiser

We are using a modified version of the *ResNet* architecture, specifically designed for the loss function explained in the next section. In this case, the *ResNet* model is not pretrained with *ImageNet*, so we will need to undergo an extensive learning process.

To adapt the outputs of the model to this, the final fully-connected block is substituted by $Q-1$ fully-connected blocks [1], each one with a single output unit with sigmoid activation.

As an alternative to the *ResNet* architecture, you can use the *VGG* architecture, which has also been implemented to work with the loss function explained in the following sections.

Finally, we define the *Adam* optimiser, which is used to adjust the network's weights and minimize the error of a loss function.

[1]: Barbero-Gómez, J., Gutiérrez, P. A., & Hervás-Martínez, C. (2022). *Error-correcting output codes in the framework of deep ordinal classification.* Neural Processing Letters, 1-32. Springer.

In [8]:
model = OBDECOCModel(num_classes, resnet18(num_classes=1000), base_n_outputs=1000).to(device)

# Optimizer
optimizer = Adam(model.parameters(), lr=optimiser_params['lr'])

## Loss Function

The original $Q$-class ordinal problem is decomposed into $Q-1$ binary decision problems, what is known as Ordinal Binary Decomposition (ODB) [1]. So the categorical cross-entropy has been substituted by the Mean Squared Error loss because it copes better with the distance function used for the Error-Conecting Output Codes (ECOC) decision:

$$
\ell = \frac{1}{N} ∑_{i=1}^N ∑_{k=1}^{Q-1} (\mathbf{1} \{y_i \succ \mathcal{C}_k\} - P(y_i \succ \mathcal{C}_k | x_i))^2
$$

where $\mathbf{1} \{y_i \succ \mathcal{C}_k\}$ is the indicator function that is equal to 1 when $y_i \succ \mathcal{C}_k$ and 0 otherwise, and $P(y_i \succ \mathcal{C}_k | x_i)$ is the probability that $y_i \succ \mathcal{C}_k$ predicted by the network given a sample.

[1]: Barbero-Gómez, J., Gutiérrez, P. A., & Hervás-Martínez, C. (2022). *Error-correcting output codes in the framework of deep ordinal classification.* Neural Processing Letters, 1-32. Springer. doi.org/10.1007/s11063-022-10824-7

In [9]:
loss_fn = OrdinalECOCDistanceLoss(
    num_classes=num_classes, weights=class_weights
).to(device)

## Metrics

In [10]:
# Metrics computation


def compute_metrics(y_true: np.ndarray, 
    y_pred: np.ndarray, 
    num_classes: int):

    if len(y_true.shape) > 1:
        y_true = np.argmax(y_true, axis=1)

    if len(y_pred.shape) > 1:
        y_pred = np.argmax(y_pred, axis=1)

    labels = range(0, num_classes)

    # Metrics calculation
    qwk = cohen_kappa_score(y_true, y_pred, weights='quadratic', labels=labels)
    ms = minimum_sensitivity(y_true, y_pred, labels=labels)
    mae = mean_absolute_error(y_true, y_pred)
    acc = accuracy_score(y_true, y_pred)
    off1 = accuracy_off1(y_true, y_pred, labels=labels)
    conf_mat = confusion_matrix(y_true, y_pred, labels=labels)

    metrics = {
        'QWK': qwk,
        'MS': ms,
        'MAE': mae,
        'CCR': acc,
        '1-off': off1,
        'Confusion matrix': conf_mat
    }

    return metrics


def _compute_sensitivities(y_true, y_pred, labels=None):
	if len(y_true.shape) > 1:
		y_true = np.argmax(y_true, axis=1)
	if len(y_pred.shape) > 1:
		y_pred = np.argmax(y_pred, axis=1)

	conf_mat = confusion_matrix(y_true, y_pred, labels=labels)

	sum = np.sum(conf_mat, axis=1)
	mask = np.eye(conf_mat.shape[0], conf_mat.shape[1])
	correct = np.sum(conf_mat * mask, axis=1)
	sensitivities = correct / sum

	sensitivities = sensitivities[~np.isnan(sensitivities)]

	return sensitivities


def minimum_sensitivity(y_true, y_pred, labels=None):
	return np.min(_compute_sensitivities(y_true, y_pred, labels=labels))


def accuracy_off1(y_true, y_pred, labels=None):
	if len(y_true.shape) > 1:
		y_true = np.argmax(y_true, axis=1)
	if len(y_pred.shape) > 1:
		y_pred = np.argmax(y_pred, axis=1)

	conf_mat = confusion_matrix(y_true, y_pred, labels=labels)
	n = conf_mat.shape[0]
	mask = np.eye(n, n) + np.eye(n, n, k=1), + np.eye(n, n, k=-1)
	correct = mask * conf_mat

	return 1.0 * np.sum(correct) / np.sum(conf_mat)


def print_metrics(metrics):
    print("")
    print('Confusion matrix :\n{}'.format(metrics['Confusion matrix']))
    print("")
    print('MS: {:.4f}'.format(metrics['MS']))
    print("")
    print('QWK: {:.4f}'.format(metrics['QWK']))
    print("")
    print('MAE: {:.4f}'.format(metrics['MAE']))
    print("")
    print('CCR: {:.4f}'.format(metrics['CCR']))
    print("")
    print('1-off: {:.4f}'.format(metrics['1-off']))

## Training process

In [11]:
def train(
    dataloader: torch.utils.data.DataLoader,
    model: OBDECOCModel,
    loss_fn: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    H: dict,
    num_classes: int,
):  # H: dict
    num_batches = len(dataloader)
    size = len(dataloader.dataset)
    progress_bar = tqdm(total=num_batches, ncols=100, position=0, desc="Train progress")
    model.train()
    mean_loss, accuracy = 0, 0
    y_pred, y_true = None, None

    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)  # Inputs and labels to device

        # Compute prediction error and accuracy of the training process
        pred = model(X)
        print(pred)
        print(y)
        loss = loss_fn(pred, y)

        mean_loss += loss
        accuracy += (pred.argmax(1) == y).type(torch.float).sum().item()

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Stack predictions and true labels to determine the confusion matrix
        pred_np = model.transformer.labels(pred).cpu().detach().numpy()
        true_np = y.cpu().detach().numpy()
        if y_pred is None:
            y_pred = pred_np
        else:
            y_pred = np.concatenate((y_pred, pred_np))

        if y_true is None:
            y_true = true_np
        else:
            y_true = np.concatenate((y_true, true_np))

        # Update progress bar
        progress_bar.set_postfix(loss=loss.item(), accuracy=accuracy)
        progress_bar.update(1)

    accuracy /= size
    mean_loss /= num_batches

    H["train_loss"].append(loss.cpu().detach().numpy())
    H["train_acc"].append(accuracy)

    # Confusion matrix for training
    labels = range(0, num_classes)
    conf_mat = confusion_matrix(y_true, y_pred, labels=labels)
    print("")
    print("Train Confusion matrix :\n{}".format(conf_mat))
    print("")

    return accuracy, mean_loss

## Testing process

In [12]:
def test(
    test_dataloader: torch.utils.data.DataLoader,
    model: OBDECOCModel,
    loss_fn: torch.nn.Module,
    device: torch.device,
    num_classes: int,
):
    num_batches = len(test_dataloader)
    progress_bar = tqdm(total=num_batches, ncols=100, position=0, desc="Test progress")
    model.eval()
    test_loss = 0
    y_pred, y_true = None, None

    with torch.no_grad():
        for batch, (X, y) in enumerate(test_dataloader):
            X, y = X.to(device), y.to(device)  # inputs and labels to device
            pred = model(X)
            test_loss += loss_fn(pred, y).item()

            # Stack predictions and true labels
            pred_np = model.transformer.labels(pred).cpu().detach().numpy()
            true_np = y.cpu().detach().numpy()
            if y_pred is None:
                y_pred = pred_np
            else:
                y_pred = np.concatenate((y_pred, pred_np))

            if y_true is None:
                y_true = true_np
            else:
                y_true = np.concatenate((y_true, true_np))

            # Update progress bar
            progress_bar.set_postfix(loss=test_loss / (batch + 1))
            progress_bar.update(1)

    test_loss /= num_batches
    metrics = compute_metrics(y_true, y_pred, num_classes)
    print_metrics(metrics)

    return metrics, test_loss

## Validation process

In [13]:
def validate(
    dataloader: torch.utils.data.DataLoader,
    model: OBDECOCModel,
    loss_fn: torch.nn.Module,
    device: torch.device,
    H: dict,
    num_classes: int,
):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    loss, accuracy = 0, 0
    y_pred, y_true = None, None

    with torch.no_grad():
        for batch, (X, y) in enumerate(dataloader):
            X, y = X.to(device), y.to(device)
            pred = model(X)
            loss += loss_fn(pred, y)
            accuracy += (pred.argmax(1) == y).type(torch.float).sum().item()

            pred_np = model.transformer.labels(pred).cpu().detach().numpy()
            true_np = y.cpu().detach().numpy()
            if y_pred is None:
                y_pred = pred_np
            else:
                y_pred = np.concatenate((y_pred, pred_np))

            if y_true is None:
                y_true = true_np
            else:
                y_true = np.concatenate((y_true, true_np))

    accuracy /= size
    loss /= num_batches

    H["val_loss"].append(loss.cpu().detach().numpy())
    H["val_acc"].append(accuracy)

    metrics = compute_metrics(y_true, y_pred, num_classes)

    return metrics, accuracy, loss

## Results

In [14]:
H = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}

# To store validation metrics
validation_metrics = {}

# Definition to store best model weights
best_model_weights = model.state_dict()
best_qwk = 0.0

# Start time
start_time = time.time()

for e in range(optimiser_params["epochs"]):
    train_acc, train_loss = train(
        train_dataloader, model, loss_fn, optimizer, device, H, num_classes=num_classes
    )
    validation_metrics, val_acc, val_loss = validate(
        val_dataloader, model, loss_fn, device, H, num_classes=num_classes
    )

    if validation_metrics["QWK"] >= best_qwk:
        best_qwk = validation_metrics["QWK"]
        best_model_weights = deepcopy(model.state_dict())

    print("[INFO] EPOCH: {}/{}".format(e + 1, optimiser_params["epochs"]))
    print("Train loss: {:.6f}, Train accuracy: {:.4f}".format(train_loss, train_acc))
    print("Val loss: {:.6f}, Val accuracy: {:.4f}\n".format(val_loss, val_acc))

# Store last train error
train_error = H["train_loss"][-1]

# Restore best weights
model.load_state_dict(best_model_weights)

# Start evaluation
print("[INFO] Network evaluation ...")

test_metrics, test_loss = test(
    test_dataloader, model, loss_fn, device, num_classes=num_classes
)

# End time
end_time = time.time()
print("\n[INFO] Total training time: {:.2f}s".format(end_time - start_time))

Train progress:   0%|                                                         | 0/4 [00:00<?, ?it/s]

tensor([[0.4912, 0.5440, 0.5716, 0.5553, 0.4533],
        [0.4801, 0.5443, 0.5829, 0.6065, 0.4515],
        [0.4903, 0.5389, 0.5636, 0.5578, 0.4593],
        [0.4928, 0.5433, 0.5640, 0.5893, 0.4636],
        [0.4999, 0.5485, 0.5699, 0.5603, 0.4691],
        [0.4862, 0.5494, 0.5668, 0.5821, 0.4698],
        [0.4791, 0.5328, 0.5517, 0.5645, 0.4729],
        [0.4375, 0.6252, 0.7488, 0.7413, 0.3644],
        [0.4941, 0.5313, 0.5622, 0.5524, 0.4552],
        [0.4983, 0.5349, 0.5862, 0.5830, 0.4612],
        [0.4672, 0.5675, 0.6398, 0.6316, 0.4874],
        [0.4821, 0.5367, 0.5761, 0.5522, 0.4565],
        [0.4852, 0.5478, 0.5606, 0.5579, 0.4540],
        [0.4795, 0.5650, 0.5509, 0.5663, 0.4736],
        [0.4718, 0.5532, 0.5644, 0.5718, 0.4818],
        [0.4679, 0.5675, 0.5650, 0.5655, 0.4665],
        [0.4932, 0.5563, 0.6577, 0.6633, 0.4699],
        [0.5057, 0.5309, 0.5646, 0.5648, 0.4514],
        [0.4898, 0.5299, 0.5731, 0.5509, 0.4546],
        [0.4771, 0.5542, 0.5730, 0.5579, 0.4536],


Train progress:  25%|██████▌                   | 1/4 [00:04<00:13,  4.65s/it, accuracy=39, loss=247]

tensor([[0.9139, 0.7853, 0.5700, 0.3353, 0.1514],
        [0.9003, 0.7873, 0.5944, 0.3545, 0.1293],
        [0.9694, 0.8326, 0.1579, 0.0480, 0.0520],
        [0.9099, 0.7803, 0.4800, 0.2738, 0.1233],
        [0.9306, 0.8188, 0.5699, 0.3504, 0.1370],
        [0.9394, 0.7702, 0.3242, 0.1758, 0.1025],
        [0.9107, 0.7885, 0.4797, 0.2544, 0.1268],
        [0.9571, 0.8189, 0.2473, 0.0989, 0.0865],
        [0.9079, 0.7911, 0.5795, 0.3759, 0.1344],
        [0.9065, 0.7888, 0.6060, 0.3557, 0.1395],
        [0.9172, 0.8066, 0.6392, 0.3711, 0.1283],
        [0.9075, 0.7870, 0.5617, 0.3215, 0.1370],
        [0.9215, 0.8045, 0.5229, 0.3068, 0.1261],
        [0.9176, 0.7832, 0.4543, 0.2241, 0.1156],
        [0.9156, 0.7820, 0.5722, 0.3071, 0.1388],
        [0.9058, 0.7824, 0.5835, 0.3727, 0.1328],
        [0.9089, 0.7883, 0.5883, 0.3744, 0.1441],
        [0.9253, 0.8015, 0.5002, 0.2497, 0.1266],
        [0.8983, 0.7731, 0.5641, 0.3209, 0.1185],
        [0.9250, 0.7745, 0.3828, 0.1647, 0.1129],


Train progress:  50%|█████████████             | 2/4 [00:09<00:09,  4.70s/it, accuracy=67, loss=199]

tensor([[0.9785, 0.6934, 0.3208, 0.1117, 0.0478],
        [0.9777, 0.6861, 0.3116, 0.0756, 0.0425],
        [0.9574, 0.6724, 0.2410, 0.1338, 0.0444],
        [0.9601, 0.6715, 0.2311, 0.1361, 0.0443],
        [0.9826, 0.7742, 0.3565, 0.1235, 0.0347],
        [0.9756, 0.7030, 0.2665, 0.1306, 0.0358],
        [0.9534, 0.6438, 0.2076, 0.1195, 0.0396],
        [0.9590, 0.6378, 0.2452, 0.1280, 0.0485],
        [0.9732, 0.7178, 0.3005, 0.1180, 0.0440],
        [0.9752, 0.7080, 0.2621, 0.1288, 0.0409],
        [0.9735, 0.7043, 0.2604, 0.1131, 0.0382],
        [0.9769, 0.7239, 0.3007, 0.1310, 0.0403],
        [0.9554, 0.6692, 0.2398, 0.1300, 0.0454],
        [0.9876, 0.6552, 0.2925, 0.0242, 0.0191],
        [0.9705, 0.7382, 0.2932, 0.1510, 0.0440],
        [0.9497, 0.6138, 0.1990, 0.0943, 0.0417],
        [0.9792, 0.6655, 0.2851, 0.0832, 0.0417],
        [0.9542, 0.6307, 0.2154, 0.1234, 0.0435],
        [0.9469, 0.5935, 0.2209, 0.0751, 0.0381],
        [0.9582, 0.6351, 0.2254, 0.1182, 0.0435],


Train progress:  75%|███████████████████▌      | 3/4 [00:13<00:04,  4.45s/it, accuracy=91, loss=266]

tensor([[9.7524e-01, 4.5278e-01, 3.5250e-01, 6.5401e-02, 2.1024e-02],
        [9.6925e-01, 5.4145e-01, 3.8620e-01, 1.1418e-01, 2.9808e-02],
        [9.7728e-01, 7.2764e-01, 4.8152e-01, 2.0032e-01, 3.2922e-02],
        [9.9179e-01, 7.5327e-01, 5.3227e-01, 1.5993e-01, 2.3001e-02],
        [9.8410e-01, 6.4503e-01, 4.4581e-01, 1.3399e-01, 2.8116e-02],
        [9.9213e-01, 7.4102e-01, 4.6265e-01, 1.5469e-01, 1.9033e-02],
        [9.8426e-01, 7.1403e-01, 4.7543e-01, 1.6186e-01, 3.0384e-02],
        [9.6650e-01, 5.0367e-01, 4.3495e-01, 1.2102e-01, 3.5086e-02],
        [9.6415e-01, 5.7482e-01, 4.1329e-01, 1.5479e-01, 3.9008e-02],
        [9.6921e-01, 5.9345e-01, 4.2657e-01, 1.6608e-01, 3.9245e-02],
        [9.7697e-01, 7.1310e-01, 4.4186e-01, 1.9257e-01, 3.4124e-02],
        [9.6644e-01, 6.5248e-01, 4.1027e-01, 2.1629e-01, 4.0248e-02],
        [9.5973e-01, 5.0617e-01, 3.8757e-01, 1.1483e-01, 3.2238e-02],
        [9.9370e-01, 7.7114e-01, 5.4950e-01, 1.3514e-01, 1.9331e-02],
        [9.8323e-01,

Train progress: 100%|█████████████████████████| 4/4 [00:15<00:00,  3.84s/it, accuracy=96, loss=80.8]


Train Confusion matrix :
[[  0   2  31  24  16   0]
 [  0   7 108  42  59   0]
 [  0   1  50  20  32   0]
 [  0   2  65  22  48   0]
 [  0   0  53  19  34   0]
 [  0   2  26   6  11   0]]






[INFO] EPOCH: 1/5
Train loss: 198.295624, Train accuracy: 0.1412
Val loss: 152.942474, Val accuracy: 0.1074



Train progress:   0%|                                                         | 0/4 [00:00<?, ?it/s]

tensor([[9.7141e-01, 4.6385e-01, 5.6449e-01, 1.2643e-01, 2.4855e-02],
        [9.7717e-01, 6.4153e-01, 6.2600e-01, 2.3592e-01, 3.9829e-02],
        [9.8805e-01, 7.4448e-01, 6.3891e-01, 2.6260e-01, 2.3412e-02],
        [9.7670e-01, 5.5676e-01, 6.7058e-01, 1.8985e-01, 3.5169e-02],
        [9.9339e-01, 2.6080e-01, 5.5674e-01, 3.2779e-02, 4.7527e-03],
        [9.8532e-01, 4.4277e-01, 5.9895e-01, 8.0722e-02, 1.3062e-02],
        [9.6671e-01, 6.1758e-01, 5.6144e-01, 2.7682e-01, 4.2791e-02],
        [9.7320e-01, 6.2929e-01, 5.7589e-01, 2.5844e-01, 4.1379e-02],
        [9.7795e-01, 6.5733e-01, 7.0831e-01, 2.3680e-01, 4.7053e-02],
        [9.5978e-01, 5.5575e-01, 5.4827e-01, 1.9089e-01, 3.6764e-02],
        [9.8402e-01, 7.0206e-01, 6.5886e-01, 2.4167e-01, 3.0940e-02],
        [9.7624e-01, 6.4555e-01, 5.8502e-01, 2.5512e-01, 3.2508e-02],
        [9.9944e-01, 1.8227e-01, 5.6530e-01, 4.9934e-03, 4.1117e-04],
        [9.9067e-01, 7.4465e-01, 6.3358e-01, 1.9323e-01, 1.9644e-02],
        [9.7975e-01,

Train progress:  25%|██████▌                   | 1/4 [00:04<00:13,  4.34s/it, accuracy=17, loss=209]

tensor([[9.8906e-01, 8.0851e-01, 7.0053e-01, 4.1692e-01, 2.8588e-02],
        [9.8824e-01, 7.3387e-01, 6.3560e-01, 3.3888e-01, 3.2061e-02],
        [9.9724e-01, 5.3799e-01, 4.8648e-01, 5.6775e-02, 2.5350e-03],
        [9.5685e-01, 4.6325e-01, 4.8198e-01, 2.4195e-01, 3.4154e-02],
        [9.9178e-01, 6.7978e-01, 6.0443e-01, 1.8062e-01, 1.2253e-02],
        [9.7144e-01, 7.2649e-01, 6.8320e-01, 4.8203e-01, 5.6040e-02],
        [9.8897e-01, 8.1645e-01, 7.4342e-01, 4.3221e-01, 3.1319e-02],
        [9.9062e-01, 8.1915e-01, 7.4018e-01, 4.6326e-01, 2.6573e-02],
        [9.9362e-01, 7.2645e-01, 6.4738e-01, 1.8773e-01, 1.1797e-02],
        [9.9347e-01, 6.5979e-01, 5.8901e-01, 1.4883e-01, 1.0317e-02],
        [9.8951e-01, 8.2295e-01, 7.4223e-01, 4.7236e-01, 3.3159e-02],
        [9.5484e-01, 5.7584e-01, 5.0548e-01, 3.3583e-01, 4.9076e-02],
        [9.7595e-01, 7.2636e-01, 6.5431e-01, 4.4562e-01, 4.3999e-02],
        [9.8000e-01, 6.9109e-01, 5.6843e-01, 3.5059e-01, 3.4826e-02],
        [9.7276e-01,

Train progress:  50%|█████████████             | 2/4 [00:09<00:10,  5.01s/it, accuracy=45, loss=189]

tensor([[9.5943e-01, 6.9081e-01, 6.1861e-01, 6.0642e-01, 7.3921e-02],
        [9.9113e-01, 7.9536e-01, 6.1547e-01, 5.0470e-01, 2.7425e-02],
        [9.9831e-01, 4.6628e-01, 1.9182e-01, 9.6293e-02, 2.1086e-03],
        [9.8066e-01, 7.0363e-01, 5.1705e-01, 4.5700e-01, 3.7348e-02],
        [9.9018e-01, 2.8614e-01, 1.7643e-01, 1.3507e-01, 5.4460e-03],
        [9.6457e-01, 6.2357e-01, 4.9656e-01, 4.4396e-01, 4.9537e-02],
        [9.9064e-01, 4.1492e-01, 2.5434e-01, 1.7908e-01, 1.0360e-02],
        [9.4761e-01, 4.4704e-01, 3.6491e-01, 3.5627e-01, 5.0085e-02],
        [9.8699e-01, 7.2069e-01, 5.2708e-01, 3.8976e-01, 2.3781e-02],
        [9.8149e-01, 7.5373e-01, 6.3498e-01, 6.0939e-01, 5.0815e-02],
        [9.7608e-01, 6.7210e-01, 5.6031e-01, 5.5692e-01, 4.3723e-02],
        [9.9012e-01, 4.2302e-01, 2.5017e-01, 1.9879e-01, 9.3497e-03],
        [9.8800e-01, 5.5343e-01, 3.7650e-01, 2.8998e-01, 1.4138e-02],
        [9.6185e-01, 4.4559e-01, 3.4086e-01, 2.9329e-01, 3.5417e-02],
        [9.6047e-01,

Train progress:  75%|███████████████████▌      | 3/4 [00:14<00:04,  4.71s/it, accuracy=67, loss=168]

tensor([[9.8636e-01, 8.1700e-01, 6.1502e-01, 5.6831e-01, 4.0832e-02],
        [9.9555e-01, 8.5240e-01, 5.1746e-01, 3.7773e-01, 1.4791e-02],
        [9.5870e-01, 7.1748e-01, 5.7020e-01, 5.7243e-01, 7.5596e-02],
        [9.8390e-01, 4.0748e-01, 1.8471e-01, 2.0344e-01, 1.3017e-02],
        [9.3585e-01, 6.1451e-01, 5.2184e-01, 5.1064e-01, 8.2748e-02],
        [9.9377e-01, 8.8971e-01, 7.1827e-01, 6.2724e-01, 2.5814e-02],
        [9.5705e-01, 3.5728e-01, 2.4297e-01, 3.4559e-01, 3.6580e-02],
        [9.3589e-01, 6.5631e-01, 5.4505e-01, 5.6635e-01, 1.0158e-01],
        [9.9423e-01, 2.5769e-01, 8.5336e-02, 1.5211e-01, 3.8300e-03],
        [9.4434e-01, 6.9509e-01, 5.9165e-01, 5.7020e-01, 9.4543e-02],
        [9.9784e-01, 9.2100e-01, 7.2266e-01, 6.1859e-01, 1.5135e-02],
        [9.1855e-01, 4.4188e-01, 3.6319e-01, 4.3025e-01, 7.6860e-02],
        [9.9749e-01, 5.8919e-01, 1.9762e-01, 1.6705e-01, 4.0442e-03],
        [9.9213e-01, 8.5729e-01, 6.6161e-01, 5.1929e-01, 2.9364e-02],
        [9.3267e-01,

Train progress: 100%|█████████████████████████| 4/4 [00:15<00:00,  4.00s/it, accuracy=73, loss=60.9]


Train Confusion matrix :
[[  0  31  10  29   3   0]
 [  0  64  16 116  20   0]
 [  0  13   7  53  30   0]
 [  0  14  12  76  35   0]
 [  0   3   8  70  25   0]
 [  0   0   0  31  14   0]]






[INFO] EPOCH: 2/5
Train loss: 156.812943, Train accuracy: 0.1074
Val loss: 222.719940, Val accuracy: 0.1074



Train progress:   0%|                                                         | 0/4 [00:00<?, ?it/s]

tensor([[9.8702e-01, 6.7866e-01, 2.8459e-01, 2.2914e-01, 2.0235e-02],
        [9.9234e-01, 8.9618e-01, 6.6734e-01, 5.9902e-01, 3.5739e-02],
        [9.9621e-01, 9.2202e-01, 7.0811e-01, 6.6555e-01, 2.7079e-02],
        [9.5333e-01, 3.8661e-01, 2.1204e-01, 3.3553e-01, 4.4102e-02],
        [9.5187e-01, 5.7586e-01, 3.3841e-01, 3.9665e-01, 6.1984e-02],
        [9.8912e-01, 8.7745e-01, 6.8436e-01, 6.6403e-01, 4.9505e-02],
        [9.8955e-01, 8.8298e-01, 6.9779e-01, 6.5825e-01, 4.6130e-02],
        [9.9619e-01, 9.2266e-01, 6.7706e-01, 6.1275e-01, 2.4410e-02],
        [9.9869e-01, 9.5043e-01, 6.2703e-01, 4.3970e-01, 1.1786e-02],
        [9.8827e-01, 8.4487e-01, 6.0723e-01, 6.0198e-01, 4.7330e-02],
        [9.9193e-01, 8.9596e-01, 6.3819e-01, 5.0489e-01, 3.6345e-02],
        [9.9864e-01, 9.4052e-01, 5.6927e-01, 3.8138e-01, 1.1365e-02],
        [9.1963e-01, 5.7298e-01, 3.9011e-01, 3.9807e-01, 8.9996e-02],
        [9.9980e-01, 9.7675e-01, 6.1976e-01, 3.9047e-01, 3.6591e-03],
        [9.9683e-01,

Train progress:  25%|██████▌                   | 1/4 [00:04<00:14,  4.96s/it, accuracy=16, loss=152]

tensor([[9.9014e-01, 5.2285e-01, 1.5087e-01, 1.8085e-01, 1.3515e-02],
        [9.8908e-01, 3.4743e-01, 7.8334e-02, 1.0934e-01, 9.0382e-03],
        [9.1380e-01, 7.2842e-01, 5.7117e-01, 4.7853e-01, 1.4116e-01],
        [9.9375e-01, 9.2812e-01, 6.9969e-01, 5.1781e-01, 3.6054e-02],
        [9.9994e-01, 6.9140e-02, 9.8581e-04, 1.0739e-02, 3.3051e-05],
        [9.9040e-01, 1.8820e-01, 3.5116e-02, 9.4459e-02, 4.8272e-03],
        [9.9998e-01, 9.9282e-01, 5.7932e-01, 2.0788e-01, 9.0462e-04],
        [9.6792e-01, 8.0982e-01, 6.0824e-01, 5.0476e-01, 8.6447e-02],
        [9.9708e-01, 1.1197e-01, 1.1671e-02, 4.7337e-02, 1.1296e-03],
        [9.6388e-01, 7.7638e-01, 5.2083e-01, 4.2827e-01, 7.3981e-02],
        [9.6446e-01, 7.9694e-01, 5.6963e-01, 4.5900e-01, 8.0966e-02],
        [9.9602e-01, 2.0750e-01, 2.9711e-02, 6.8647e-02, 3.0012e-03],
        [9.9990e-01, 6.1647e-02, 9.7756e-04, 1.1022e-02, 3.4976e-05],
        [9.2036e-01, 5.2136e-01, 3.3110e-01, 3.3528e-01, 8.0244e-02],
        [9.7435e-01,

Train progress:  50%|█████████████             | 2/4 [00:09<00:09,  4.96s/it, accuracy=40, loss=158]

tensor([[8.1158e-01, 5.7839e-01, 4.2049e-01, 3.2309e-01, 1.6914e-01],
        [9.0280e-01, 3.2219e-01, 1.2387e-01, 1.2857e-01, 3.6357e-02],
        [9.9980e-01, 1.1543e-01, 1.0314e-03, 8.6337e-03, 7.4386e-05],
        [8.9966e-01, 7.5120e-01, 5.7749e-01, 4.2908e-01, 1.6375e-01],
        [9.9737e-01, 2.1725e-01, 1.1549e-02, 3.8106e-02, 1.3622e-03],
        [9.9114e-01, 9.2789e-01, 7.3725e-01, 5.0248e-01, 5.7103e-02],
        [9.8960e-01, 9.2299e-01, 7.6805e-01, 5.6223e-01, 6.5640e-02],
        [9.3213e-01, 5.4897e-01, 2.3598e-01, 1.8797e-01, 5.2160e-02],
        [9.3878e-01, 6.7979e-01, 4.2413e-01, 2.9126e-01, 8.2209e-02],
        [8.2882e-01, 4.0631e-01, 2.2528e-01, 1.8678e-01, 7.7367e-02],
        [9.7850e-01, 3.2312e-01, 6.0708e-02, 1.0123e-01, 1.2443e-02],
        [9.9898e-01, 9.6324e-01, 5.6947e-01, 2.1447e-01, 1.1110e-02],
        [9.9991e-01, 9.9275e-01, 8.1270e-01, 3.4350e-01, 4.5758e-03],
        [8.5294e-01, 5.1659e-01, 3.2950e-01, 2.5321e-01, 1.0490e-01],
        [9.1283e-01,

Train progress:  75%|███████████████████▌      | 3/4 [00:14<00:04,  4.77s/it, accuracy=64, loss=184]

tensor([[9.9676e-01, 9.6197e-01, 7.8791e-01, 5.6312e-01, 4.6399e-02],
        [9.5950e-01, 7.4022e-01, 4.0548e-01, 2.7129e-01, 8.1833e-02],
        [9.9957e-01, 9.7494e-01, 4.9705e-01, 1.3748e-01, 7.6419e-03],
        [9.9684e-01, 9.6381e-01, 8.4563e-01, 6.4633e-01, 5.4985e-02],
        [9.7190e-01, 8.6012e-01, 6.2057e-01, 3.6092e-01, 8.8955e-02],
        [9.3509e-01, 6.4382e-01, 3.7609e-01, 2.8075e-01, 8.2666e-02],
        [9.6855e-01, 8.8906e-01, 7.6756e-01, 5.7572e-01, 1.5220e-01],
        [9.2498e-01, 3.2749e-01, 9.5988e-02, 9.3133e-02, 2.8121e-02],
        [9.4375e-01, 7.7577e-01, 5.3626e-01, 3.4122e-01, 1.1250e-01],
        [9.8937e-01, 8.9544e-01, 5.1869e-01, 2.3036e-01, 4.2627e-02],
        [9.8393e-01, 2.2940e-01, 2.1273e-02, 4.9125e-02, 5.2294e-03],
        [9.9996e-01, 9.9363e-01, 6.5432e-01, 2.0055e-01, 2.9413e-03],
        [9.9967e-01, 9.8826e-01, 8.6204e-01, 6.0901e-01, 1.7447e-02],
        [9.9934e-01, 2.0566e-01, 2.3305e-03, 8.9037e-03, 2.2459e-04],
        [9.9762e-01,

Train progress: 100%|███████████████████████████| 4/4 [00:16<00:00,  4.13s/it, accuracy=73, loss=59]


Train Confusion matrix :
[[  0  57  10   3   3   0]
 [  0 110  61  31  14   0]
 [  0  26  34  31  12   0]
 [  0  26  25  58  28   0]
 [  0  11  18  37  40   0]
 [  0   2   7  17  19   0]]






[INFO] EPOCH: 3/5
Train loss: 138.412659, Train accuracy: 0.1074
Val loss: 176.453369, Val accuracy: 0.1074



Train progress:   0%|                                                         | 0/4 [00:00<?, ?it/s]

tensor([[9.7345e-01, 7.0695e-01, 3.2657e-01, 1.9400e-01, 4.9626e-02],
        [9.7275e-01, 3.6306e-01, 7.2629e-02, 5.2764e-02, 1.2981e-02],
        [9.0933e-01, 7.7574e-01, 6.1937e-01, 4.4664e-01, 2.1256e-01],
        [7.1294e-01, 4.6626e-01, 3.2597e-01, 2.0635e-01, 1.5411e-01],
        [9.9853e-01, 9.7504e-01, 8.5381e-01, 6.6955e-01, 4.9231e-02],
        [9.8361e-01, 8.9519e-01, 6.1336e-01, 3.1597e-01, 8.7575e-02],
        [9.6465e-01, 8.7949e-01, 7.6484e-01, 6.1372e-01, 1.8128e-01],
        [8.6906e-01, 7.4136e-01, 5.7944e-01, 3.6470e-01, 2.2222e-01],
        [9.2154e-01, 6.4030e-01, 3.0850e-01, 1.5320e-01, 6.9678e-02],
        [9.9864e-01, 9.4981e-01, 4.4884e-01, 1.4776e-01, 1.5441e-02],
        [7.6630e-01, 4.5788e-01, 2.9296e-01, 1.7264e-01, 1.1825e-01],
        [9.9912e-01, 9.8215e-01, 8.9521e-01, 7.7777e-01, 4.5598e-02],
        [9.9370e-01, 9.5102e-01, 8.3746e-01, 6.5520e-01, 9.9978e-02],
        [9.9628e-01, 9.5794e-01, 7.2878e-01, 3.8594e-01, 5.2631e-02],
        [9.9907e-01,

Train progress:  25%|██████▌                   | 1/4 [00:05<00:15,  5.21s/it, accuracy=20, loss=157]

tensor([[9.9158e-01, 9.4578e-01, 8.6029e-01, 7.1683e-01, 1.6265e-01],
        [9.9904e-01, 9.5996e-01, 5.7075e-01, 1.9385e-01, 2.1031e-02],
        [9.8634e-01, 1.0159e-01, 1.0455e-02, 1.3255e-02, 2.0482e-03],
        [9.9451e-01, 9.5208e-01, 8.3578e-01, 6.5589e-01, 1.2759e-01],
        [9.2262e-01, 7.8799e-01, 6.0695e-01, 3.2670e-01, 1.7856e-01],
        [7.3494e-01, 3.5691e-01, 2.3571e-01, 1.3947e-01, 1.0867e-01],
        [9.8879e-01, 9.3526e-01, 8.4667e-01, 6.8988e-01, 1.7092e-01],
        [9.9836e-01, 1.4041e-01, 4.5938e-03, 5.4093e-03, 4.1122e-04],
        [9.6699e-01, 8.8365e-01, 7.2971e-01, 4.3855e-01, 1.8358e-01],
        [9.8083e-01, 7.8590e-01, 3.9304e-01, 1.4252e-01, 4.8561e-02],
        [6.7978e-01, 5.0697e-01, 4.0905e-01, 2.6800e-01, 2.2116e-01],
        [9.9378e-01, 5.1251e-01, 7.0320e-02, 3.7678e-02, 6.7465e-03],
        [7.4256e-01, 5.5982e-01, 4.4155e-01, 2.7052e-01, 2.2058e-01],
        [9.4047e-01, 7.2645e-01, 4.5596e-01, 2.3062e-01, 1.0639e-01],
        [9.5629e-01,

Train progress:  50%|█████████████             | 2/4 [00:09<00:09,  4.78s/it, accuracy=44, loss=144]

tensor([[9.1315e-01, 1.1353e-01, 3.0696e-02, 1.3867e-02, 7.4285e-03],
        [6.3041e-01, 4.2482e-01, 3.4405e-01, 1.6618e-01, 1.6339e-01],
        [9.8041e-01, 9.0941e-01, 8.1471e-01, 6.2215e-01, 2.5476e-01],
        [9.9705e-01, 9.5034e-01, 7.7364e-01, 4.4816e-01, 1.0391e-01],
        [6.9368e-01, 4.3790e-01, 3.1567e-01, 1.4416e-01, 1.3422e-01],
        [9.5806e-01, 8.7176e-01, 8.0686e-01, 6.6403e-01, 3.1631e-01],
        [9.8819e-01, 9.2962e-01, 8.6994e-01, 7.7828e-01, 2.6360e-01],
        [9.9961e-01, 9.8485e-01, 9.0549e-01, 7.4102e-01, 8.1474e-02],
        [9.9699e-01, 1.0587e-01, 6.4831e-03, 5.1961e-03, 6.3737e-04],
        [9.8910e-01, 9.3634e-01, 8.5521e-01, 6.7102e-01, 2.2218e-01],
        [8.1610e-01, 2.2798e-01, 8.9027e-02, 3.3417e-02, 2.5275e-02],
        [9.9796e-01, 4.3942e-02, 2.0800e-03, 2.4848e-03, 2.0213e-04],
        [7.1761e-01, 3.7080e-01, 2.9016e-01, 1.6557e-01, 1.3286e-01],
        [6.8053e-01, 4.0160e-01, 2.8766e-01, 1.1907e-01, 1.1639e-01],
        [8.9478e-01,

Train progress:  75%|███████████████████▌      | 3/4 [00:13<00:04,  4.35s/it, accuracy=65, loss=137]

tensor([[8.7479e-01, 7.8642e-01, 7.1508e-01, 4.9776e-01, 3.9365e-01],
        [9.9989e-01, 9.5182e-01, 5.7300e-01, 1.5426e-01, 1.6401e-02],
        [9.9883e-01, 1.4053e-02, 6.6290e-04, 3.1136e-04, 3.2175e-05],
        [9.9641e-01, 9.6120e-01, 9.1778e-01, 8.6908e-01, 3.5540e-01],
        [9.8327e-01, 9.1050e-01, 7.7949e-01, 4.7423e-01, 2.5381e-01],
        [9.8650e-01, 3.9574e-01, 1.0081e-01, 2.8371e-02, 1.1047e-02],
        [7.4040e-01, 5.5544e-01, 4.2081e-01, 1.7409e-01, 1.7331e-01],
        [7.0794e-01, 7.4152e-02, 3.4740e-02, 1.4394e-02, 1.1112e-02],
        [9.9805e-01, 9.6673e-01, 8.9421e-01, 7.1783e-01, 2.4380e-01],
        [9.0213e-01, 8.0676e-01, 6.9499e-01, 3.7874e-01, 3.1220e-01],
        [5.6322e-01, 2.1955e-01, 1.4461e-01, 5.7330e-02, 6.1471e-02],
        [9.9999e-01, 9.8690e-01, 6.2545e-01, 1.1559e-01, 6.6501e-03],
        [9.1252e-01, 7.9744e-01, 7.2025e-01, 4.7902e-01, 3.4286e-01],
        [9.9459e-01, 8.4792e-01, 4.4797e-01, 1.7153e-01, 6.3500e-02],
        [9.9896e-01,

Train progress: 100%|█████████████████████████| 4/4 [00:15<00:00,  3.89s/it, accuracy=73, loss=48.3]


Train Confusion matrix :
[[ 0 60  7  4  2  0]
 [ 0 90 71 34 21  0]
 [ 0 18 36 30 19  0]
 [ 0 12 35 49 41  0]
 [ 0  5  7 30 64  0]
 [ 0  1  5  5 34  0]]






[INFO] EPOCH: 4/5
Train loss: 121.257805, Train accuracy: 0.1074
Val loss: 168.987823, Val accuracy: 0.1074



Train progress:   0%|                                                         | 0/4 [00:00<?, ?it/s]

tensor([[9.9998e-01, 9.9416e-01, 9.5223e-01, 8.5698e-01, 1.8323e-01],
        [9.5982e-01, 8.8295e-01, 8.1945e-01, 6.1244e-01, 4.5469e-01],
        [9.9919e-01, 9.7387e-01, 9.0561e-01, 7.3590e-01, 3.1120e-01],
        [5.5568e-01, 4.5201e-01, 4.1361e-01, 1.8017e-01, 2.0533e-01],
        [7.1918e-01, 2.6615e-01, 1.8359e-01, 5.9021e-02, 6.2473e-02],
        [9.2672e-01, 1.6677e-02, 4.1594e-03, 2.4714e-03, 6.6581e-04],
        [5.1707e-01, 1.2240e-01, 8.1918e-02, 2.7222e-02, 2.9733e-02],
        [9.9859e-01, 9.6264e-01, 8.5542e-01, 6.2491e-01, 2.3867e-01],
        [8.5209e-01, 5.8882e-01, 4.3716e-01, 1.3955e-01, 1.3308e-01],
        [6.2536e-01, 4.1514e-01, 2.9201e-01, 7.7601e-02, 9.3799e-02],
        [9.8762e-01, 7.0508e-01, 4.2507e-01, 1.3621e-01, 6.4989e-02],
        [7.1739e-01, 2.6005e-01, 1.2056e-01, 2.7001e-02, 3.2394e-02],
        [4.7001e-01, 3.0200e-01, 2.4734e-01, 9.7526e-02, 1.1793e-01],
        [5.6723e-01, 1.2797e-01, 9.5813e-02, 3.9953e-02, 3.9810e-02],
        [9.2497e-01,

Train progress:  25%|██████▌                   | 1/4 [00:04<00:13,  4.57s/it, accuracy=17, loss=102]

tensor([[7.2171e-01, 4.8664e-01, 3.5633e-01, 1.2187e-01, 1.2924e-01],
        [9.9429e-01, 9.0233e-01, 6.2525e-01, 1.6431e-01, 1.0319e-01],
        [9.9531e-01, 9.6195e-01, 9.2479e-01, 8.3007e-01, 5.6830e-01],
        [3.0214e-01, 9.4073e-02, 6.6118e-02, 1.7400e-02, 2.4403e-02],
        [5.7800e-01, 4.7152e-01, 3.3716e-01, 8.9144e-02, 1.1664e-01],
        [9.9609e-01, 9.5309e-01, 8.8270e-01, 7.6024e-01, 4.3329e-01],
        [7.7406e-01, 1.2591e-01, 6.0795e-02, 1.4554e-02, 1.1818e-02],
        [9.8136e-01, 9.2415e-01, 8.3399e-01, 5.0315e-01, 3.7625e-01],
        [9.1889e-01, 8.0120e-01, 6.5615e-01, 2.7085e-01, 2.7705e-01],
        [6.3459e-01, 3.4536e-01, 2.1065e-01, 3.8944e-02, 4.9409e-02],
        [8.2805e-01, 7.6054e-01, 6.4234e-01, 2.4945e-01, 2.9261e-01],
        [1.0000e+00, 9.9089e-01, 7.9892e-01, 2.7029e-01, 2.1851e-02],
        [9.5129e-01, 4.2276e-01, 1.5441e-01, 2.6180e-02, 1.8834e-02],
        [8.4789e-01, 4.4268e-01, 2.8142e-01, 9.5787e-02, 7.9417e-02],
        [9.3587e-01,

Train progress:  50%|█████████████             | 2/4 [00:08<00:08,  4.30s/it, accuracy=41, loss=104]

tensor([[3.5521e-01, 1.5893e-01, 1.0051e-01, 3.2879e-02, 4.2972e-02],
        [9.9510e-01, 9.4211e-01, 8.2527e-01, 4.6639e-01, 2.8218e-01],
        [2.3856e-01, 2.8437e-02, 1.6993e-02, 5.7402e-03, 5.7702e-03],
        [4.1087e-01, 5.2752e-02, 3.1061e-02, 8.8161e-03, 8.6823e-03],
        [6.5576e-01, 2.1516e-01, 1.3656e-01, 4.2989e-02, 3.9873e-02],
        [2.3170e-01, 1.0001e-01, 6.9722e-02, 2.1422e-02, 2.7213e-02],
        [1.3447e-01, 1.6314e-02, 8.3454e-03, 2.1951e-03, 2.3036e-03],
        [9.9949e-01, 9.8294e-01, 9.2590e-01, 8.5415e-01, 4.6702e-01],
        [9.8527e-01, 9.3057e-01, 8.6864e-01, 7.5617e-01, 5.8350e-01],
        [9.3451e-01, 2.9832e-02, 6.1749e-03, 1.5077e-03, 6.1334e-04],
        [9.9917e-01, 5.0639e-01, 1.0696e-01, 1.8012e-02, 3.5114e-03],
        [9.9495e-01, 9.4612e-01, 8.6859e-01, 7.1604e-01, 4.6735e-01],
        [9.9549e-01, 9.6232e-01, 9.1244e-01, 8.3816e-01, 5.7555e-01],
        [8.1478e-01, 1.5842e-01, 7.8772e-02, 3.2032e-02, 1.6211e-02],
        [9.9527e-01,

Train progress:  75%|████████████████████▎      | 3/4 [00:12<00:04,  4.07s/it, accuracy=64, loss=89]

tensor([[9.5353e-01, 7.8739e-01, 4.7026e-01, 1.2972e-01, 1.0673e-01],
        [6.6898e-01, 1.6496e-01, 9.4307e-02, 5.4727e-02, 3.1196e-02],
        [9.7474e-01, 8.6594e-01, 6.9841e-01, 4.6717e-01, 3.5430e-01],
        [3.6497e-01, 7.1730e-02, 3.3798e-02, 1.0961e-02, 9.4594e-03],
        [8.4231e-01, 6.0782e-01, 3.2134e-01, 7.1365e-02, 6.0522e-02],
        [9.9085e-01, 9.4759e-01, 8.7072e-01, 7.1553e-01, 4.9305e-01],
        [2.7170e-01, 1.6472e-01, 1.2723e-01, 5.4876e-02, 6.1044e-02],
        [3.6266e-01, 2.5784e-01, 1.2277e-01, 2.1403e-02, 2.5536e-02],
        [9.9984e-01, 9.7968e-01, 8.3888e-01, 5.3070e-01, 2.0960e-01],
        [9.9424e-01, 9.6587e-01, 9.1622e-01, 7.8524e-01, 6.0975e-01],
        [9.9742e-01, 9.5237e-01, 7.0811e-01, 1.7777e-01, 1.0158e-01],
        [1.3784e-01, 2.6337e-02, 1.3766e-02, 3.5894e-03, 3.7350e-03],
        [9.4917e-01, 8.7464e-01, 6.4570e-01, 2.2300e-01, 1.9512e-01],
        [9.7483e-01, 8.3552e-01, 5.8198e-01, 2.5241e-01, 1.7479e-01],
        [1.5057e-01,

Train progress: 100%|█████████████████████████| 4/4 [00:14<00:00,  3.55s/it, accuracy=73, loss=35.5]


Train Confusion matrix :
[[46 25  1  1  0  0]
 [56 84 39 25 11  1]
 [ 5 21 33 33  9  2]
 [ 6 11 26 55 32  7]
 [ 0  0  2 38 48 18]
 [ 0  0  0  9 14 22]]






[INFO] EPOCH: 5/5
Train loss: 82.521561, Train accuracy: 0.1074
Val loss: 151.584747, Val accuracy: 0.2975

[INFO] Network evaluation ...


Test progress: 100%|████████████████████████████████████████| 2/2 [00:02<00:00,  1.05s/it, loss=135]


Confusion matrix :
[[17  0  2  4  0  0]
 [39  0  4  5  0  0]
 [29  0  7  6  1  0]
 [30  0 12  6  0  0]
 [14  0  0  9  0  0]
 [ 6  0  3  7  0  0]]

MS: 0.0000

QWK: 0.1455

MAE: 1.7662

CCR: 0.1493

1-off: 0.4975

[INFO] Total training time: 85.39s



