In [1]:
!pip install monai
!pip install einops

Collecting monai
  Downloading monai-1.3.2-py3-none-any.whl.metadata (10 kB)
Downloading monai-1.3.2-py3-none-any.whl (1.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m13.8 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: monai
Successfully installed monai-1.3.2
Collecting einops
  Downloading einops-0.8.0-py3-none-any.whl.metadata (12 kB)
Downloading einops-0.8.0-py3-none-any.whl (43 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.2/43.2 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.8.0


In [2]:
from einops import rearrange
from pathlib import Path
import math
import numpy
import torch
from torch import nn, einsum
from functools import partial
from torch.autograd.function import Function
import numpy as np
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
# Uncomment this line for MNIST training.
from torchvision.datasets.mnist import MNIST
from torchvision import transforms
from tqdm import tqdm, trange
from torch.optim import Adam
from torch.jit import fork, wait
from torch.cuda.amp import autocast, GradScaler
from torch.nn import DataParallel
import torch.nn.functional as F

np.random.seed(42)
torch.manual_seed(42)
EPSILON = 1e-10

In [3]:
def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

class FlashAttentionFunction(Function):
    @staticmethod
    @torch.no_grad()
    def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
        """ Algorithm 1 in the v2 paper """

        device = q.device
        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        o = torch.zeros_like(q)
        all_row_sums = torch.zeros((*q.shape[:-1], 1), device = device)
        all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, device = device)

        scale = (q.shape[-1] ** -0.5)

        num_row_tiles = math.ceil(q.shape[-2] / q_bucket_size)
        num_col_tiles = math.ceil(k.shape[-2] / k_bucket_size)

        if exists(mask) and mask.ndim == 2:
            mask = rearrange(mask, 'b n -> b 1 1 n')

        if not exists(mask):
            col_masks = (None,) * num_col_tiles
            mask = (col_masks,) * num_row_tiles 
        else:
            mask = ((mask,) * num_row_tiles) if mask.shape[-2] == 1 else mask.split(q_bucket_size, dim = -2)
            mask = tuple(((row_mask,) * num_col_tiles) if row_mask.shape[-1] == 1 else row_mask.split(k_bucket_size, dim = -1) for row_mask in mask)

        row_splits = zip(
            q.split(q_bucket_size, dim = -2),
            o.split(q_bucket_size, dim = -2),
            mask,
            all_row_sums.split(q_bucket_size, dim = -2),
            all_row_maxes.split(q_bucket_size, dim = -2),
        )

        for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim = -2),
                v.split(k_bucket_size, dim = -2),
                row_mask
            )

            for k_ind, (kc, vc, col_mask) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale

                if exists(col_mask):
                    attn_weights.masked_fill_(~col_mask, max_neg_value)

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                block_row_maxes = attn_weights.amax(dim = -1, keepdims = True)
                new_row_maxes = torch.maximum(block_row_maxes, row_maxes)

                exp_weights = torch.exp(attn_weights - new_row_maxes)

                if exists(col_mask):
                    exp_weights.masked_fill_(~col_mask, 0.)

                block_row_sums = exp_weights.sum(dim = -1, keepdims = True).clamp(min = EPSILON)

                exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)

                exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)

                new_row_sums = exp_row_max_diff * row_sums + block_row_sums

                oc.mul_(exp_row_max_diff).add_(exp_values)

                row_maxes.copy_(new_row_maxes)
                row_sums.copy_(new_row_sums)

            oc.div_(row_sums)

        lse = all_row_sums.log() + all_row_maxes

        ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
        ctx.save_for_backward(q, k, v, o, lse)

        return o

    @staticmethod
    @torch.no_grad()
    def backward(ctx, do):
        """ Algorithm 2 in the v2 paper """

        causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
        q, k, v, o, lse = ctx.saved_tensors

        device = q.device

        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        dq = torch.zeros_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)

        row_splits = zip(
            q.split(q_bucket_size, dim = -2),
            o.split(q_bucket_size, dim = -2),
            do.split(q_bucket_size, dim = -2),
            mask,
            lse.split(q_bucket_size, dim = -2),
            dq.split(q_bucket_size, dim = -2)
        )

        for ind, (qc, oc, doc, row_mask, lsec, dqc) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim = -2),
                v.split(k_bucket_size, dim = -2),
                dk.split(k_bucket_size, dim = -2),
                dv.split(k_bucket_size, dim = -2),
                row_mask
            )

            for k_ind, (kc, vc, dkc, dvc, col_mask) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                p = torch.exp(attn_weights - lsec)

                if exists(col_mask):
                    p.masked_fill_(~col_mask, 0.)

                dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
                dp = einsum('... i d, ... j d -> ... i j', doc, vc)

                D = (doc * oc).sum(dim = -1, keepdims = True)
                ds = p * scale * (dp - D)

                dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
                dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)

                dqc.add_(dq_chunk)
                dkc.add_(dk_chunk)
                dvc.add_(dv_chunk)

        return dq, dk, dv, None, None, None, None

In [4]:
class FlashAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        heads = 8,
        dim_head = 64,
        causal = False,
        q_bucket_size = 512,
        k_bucket_size = 1024,
        parallel = False,
        mixed_precision = False
    ):
        super().__init__()
        self.heads = heads
        self.causal = causal
        self.parallel = parallel
        self.mixed_precision = mixed_precision

        inner_dim = heads * dim_head

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

        # memory efficient attention related parameters
        # can be overriden on forward
        self.q_bucket_size = q_bucket_size
        self.k_bucket_size = k_bucket_size

        if self.parallel:
            self.model = DataParallel(self)
        if self.mixed_precision:
            self.scaler = GradScaler()

    def forward(
        self,
        x,
        context = None,
        mask = None,
        q_bucket_size = None,
        k_bucket_size = None,
    ):
        q_bucket_size = default(q_bucket_size, self.q_bucket_size)
        k_bucket_size = default(k_bucket_size, self.k_bucket_size)

        h = self.heads
        context = default(context, x)

        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim=-1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))

        if self.parallel:
            # Split the input data into chunks and move each chunk to the correct GPU
            num_gpus = torch.cuda.device_count()
            x_chunks = x.split(x.size(0) // num_gpus)
            x_chunks = [chunk.to(f'cuda:{i}') for i, chunk in enumerate(x_chunks)]
            q = x_chunks

        if self.mixed_precision:
            # Use autocast to allow operations to run in lower precision
            with autocast():
                out = FlashAttentionFunction.apply(q, k, v, mask, self.causal, q_bucket_size, k_bucket_size)
        else:
            out = FlashAttentionFunction.apply(q, k, v, mask, self.causal, q_bucket_size, k_bucket_size)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

In [5]:
class KANLinear(torch.nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KANLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = torch.nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = torch.nn.Parameter(
                torch.Tensor(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = (
            self.grid
        )  # (in_features, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """
        Compute the coefficients of the curve that interpolates the given points.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).

        Returns:
            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        A = self.b_splines(x).transpose(
            0, 1
        )  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
        solution = torch.linalg.lstsq(
            A, B
        ).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(
            2, 0, 1
        )  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x: torch.Tensor):
        assert x.size(-1) == self.in_features
        original_shape = x.shape
        x = x.reshape(-1, self.in_features)

        base_output = F.linear(self.base_activation(x), self.base_weight)
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        output = base_output + spline_output

        output = output.reshape(*original_shape[:-1], self.out_features)
        return output

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        splines = self.b_splines(x)  # (batch, in, coeff)
        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)
        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(
            1, 0, 2
        )  # (batch, in, out)

        # sort each channel individually to collect data distribution
        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[
            torch.linspace(
                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
            )
        ]

        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        grid = torch.concatenate(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute the regularization loss.

        This is a dumb simulation of the original L1 regularization as stated in the
        paper, since the original one requires computing absolutes and entropy from the
        expanded (batch, in_features, out_features) intermediate tensor, which is hidden
        behind the F.linear function if we want an memory efficient implementation.

        The L1 regularization is now computed as mean absolute value of the spline
        weights. The authors implementation also includes this term in addition to the
        sample-based regularization.
        """
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        p = l1_fake / regularization_loss_activation
        regularization_loss_entropy = -torch.sum(p * p.log())
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )

In [6]:
class FlashKANViT(torch.nn.Module):
    """
    The workflow will be as follows.
        1. Find the linear mapping of the input
        2. Embed them using the function that we have written
        3. Use 'n' MSA blocks and add a linear and a softmax layer at the end
    """

    def __init__(self, chw, n_patches=16, n_blocks=2, hidden_d=8, n_heads=4, out_d=10):
        super(FlashKANViT, self).__init__()

        self.chw = chw
        self.n_patches = n_patches
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.hidden_d = hidden_d

        # Input and patch sizes
        assert chw[1] % n_patches == 0
        assert chw[2] % n_patches == 0
        self.patch_size = (chw[1] / n_patches, chw[2] / n_patches)

        # Linear mapping
        self.input_d = int(chw[0] * self.patch_size[0] * self.patch_size[1])
        self.linear_mapper = KANLinear(self.input_d, self.hidden_d)

        # Classification token
        self.v_class = torch.nn.Parameter(torch.rand(1, self.hidden_d))

        # Positional embedding
        self.register_buffer('pos_embeddings', self.positional_embeddings(n_patches ** 2 + 1, hidden_d),
                             persistent=False)

        # Encoder blocks
        self.blocks = torch.nn.ModuleList([FlashAttention(dim = hidden_d, heads = n_heads) for _ in range(n_blocks)])

        self.mlp = torch.nn.Sequential(
            KANLinear(self.hidden_d, out_d),
            torch.nn.Softmax(dim=-1)
        )
        
    def patchify(self, images, n_patches):
        """
        In order to "sequentially" pass in the images, we can break down the main image into multiple sub-images
        and map them to a vector. This is exactly what this function does.

        Arguments:
        images: The image passed into this function
        n_patches: The number of patches to split the image into.

        Returns our patches aka the sub-images.
        """
        n, c, h, w = images.shape

        assert h == w, "Only for square images"

        patches = torch.zeros(n, n_patches ** 2, h * w * c // n_patches ** 2)
        patch_size = h // n_patches

        for idx, image in enumerate(images):
            for i in range(n_patches):
                for j in range(n_patches):
                    patch = image[:, i * patch_size: (i + 1) * patch_size, j * patch_size: (j + 1) * patch_size]
                    patches[idx, i * n_patches + j] = patch.flatten()
        return patches
    
    def positional_embeddings(self, sequence_length, d):
        """
        In order for the model to know where to place each image, one can use positional embeddings where high freq values
        are classified into the first few dimensions while low frequency values are added on to the latter dimensions. This
        function performs exactly that. It has two parameters.

        Arguments:
        sequence_length: The number of tokens for the dataset.
        d: The dimensionality for each token.

        Returns a matrix where each (i,j) is added as token i in dimension j.
        """
        result = torch.ones(sequence_length, d)
        for i in range(sequence_length):
            for j in range(d):
                result[i][j] = np.sin(i / (10000 ** (j / d))) if j % 2 == 0 else np.cos(i / (10000 ** (j / d)))
        return result

    def forward(self, images):
        n, c, h, w = images.shape
        patches = self.patchify(images, self.n_patches).to(self.pos_embeddings.device)

        # rutorch.nning tokenization
        tokens = self.linear_mapper(patches)
        tokens = torch.cat((self.v_class.expand(n, 1, -1), tokens), dim=1)
        out = tokens + self.pos_embeddings.repeat(n, 1, 1)

        for block in self.blocks:
            out = block(out)

        out = out[:, 0]
        return self.mlp(out)

In [7]:
lr = 2e-3
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
mnist_model = FlashKANViT((1, 28, 28), n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10).to(device)
optimizer = Adam(mnist_model.parameters(), lr=lr)

In [None]:
import os
import torch
import datetime
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchvision.datasets import MNIST
from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score, roc_auc_score
from tqdm import tqdm, trange

def calculate_metrics(y_true, y_pred, y_pred_proba):
    accuracy = accuracy_score(y_true, y_pred)
    balanced_accuracy = balanced_accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, average='weighted')

    y_true_bin = torch.nn.functional.one_hot(
        torch.tensor(y_true), num_classes=10).numpy()
    roc_auc = roc_auc_score(y_true_bin, y_pred_proba,
                            average='weighted', multi_class='ovr')

    return accuracy, balanced_accuracy, f1, roc_auc

def save_metrics(filename, epoch, phase, loss, accuracy, balanced_accuracy, f1, roc_auc):
    os.makedirs('logs', exist_ok=True)
    with open(f'logs/{filename}', 'a') as f:
        f.write(f"Epoch: {epoch}, Phase: {phase}\n")
        f.write(f"  Loss: {loss:.4f}\n")
        f.write(f"  Accuracy: {accuracy:.4f}\n")
        f.write(f"  Balanced Accuracy: {balanced_accuracy:.4f}\n")
        f.write(f"  F1 Score: {f1:.4f}\n")
        f.write(f"  ROC AUC: {roc_auc:.4f}\n\n")

def main(train_loader, test_loader, epochs: int):
    print("Using device: ", device,
          f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "")
    criterion = torch.nn.CrossEntropyLoss()
    
    # Create a unique filename and TensorBoard writer for this run
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    log_filename = f"flashkan_{epochs}epochs_{timestamp}.txt"
    train_writer = SummaryWriter(log_dir=f"runs/flashkan_{epochs}epochs_{timestamp}/train")
    test_writer = SummaryWriter(log_dir=f"runs/flashkan_{epochs}epochs_{timestamp}/test")

    for epoch in trange(epochs, desc="train"):
        train_loss = 0.0
        y_true_train, y_pred_train, y_pred_proba_train = [], [], []

        mnist_model.train()
        for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1} in training", leave=False):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = mnist_model(x)
            loss = criterion(y_hat, y)

            train_loss += loss.detach().cpu().item() / len(train_loader)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            y_true_train.extend(y.cpu().numpy())
            y_pred_train.extend(torch.argmax(y_hat, dim=1).cpu().numpy())
            y_pred_proba_train.extend(torch.nn.functional.softmax(
                y_hat, dim=1).detach().cpu().numpy())

        # Calculate training metrics
        accuracy, balanced_accuracy, f1, roc_auc = calculate_metrics(
            y_true_train, y_pred_train, y_pred_proba_train)

        # Log training metrics to TensorBoard
        train_writer.add_scalar('Loss/train', train_loss, epoch)
        train_writer.add_scalar('Accuracy/train', accuracy, epoch)
        train_writer.add_scalar('Balanced Accuracy/train', balanced_accuracy, epoch)
        train_writer.add_scalar('F1 Score/train', f1, epoch)
        train_writer.add_scalar('ROC AUC/train', roc_auc, epoch)
        
        print(f"Epoch: {epoch+1}, Phase: Train")
        print(f"  Loss: {train_loss:.4f}")
        print(f"  Accuracy: {accuracy:.4f}")
        print(f"  Balanced Accuracy: {balanced_accuracy:.4f}")
        print(f"  F1 Score: {f1:.4f}")
        print(f"  ROC AUC: {roc_auc:.4f}")

        # Save metrics for the last epoch
        if epoch == epochs - 1:
            save_metrics(log_filename, epoch + 1, "Train",
                         train_loss, accuracy, balanced_accuracy, f1, roc_auc)

    # Testing
    mnist_model.eval()
    with torch.no_grad():
        test_loss = 0.0
        y_true_test, y_pred_test, y_pred_proba_test = [], [], []

        for batch in tqdm(test_loader, desc="Testing"):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = mnist_model(x)
            loss = criterion(y_hat, y)
            test_loss += loss.detach().cpu().item() / len(test_loader)

            y_true_test.extend(y.cpu().numpy())
            y_pred_test.extend(torch.argmax(y_hat, dim=1).cpu().numpy())
            y_pred_proba_test.extend(
                torch.nn.functional.softmax(y_hat, dim=1).cpu().numpy())

        # Calculate test metrics
        accuracy, balanced_accuracy, f1, roc_auc = calculate_metrics(
            y_true_test, y_pred_test, y_pred_proba_test)

        # Log test metrics to TensorBoard
        test_writer.add_scalar('Loss/test', test_loss, epochs)
        test_writer.add_scalar('Accuracy/test', accuracy, epochs)
        test_writer.add_scalar('Balanced Accuracy/test', balanced_accuracy, epochs)
        test_writer.add_scalar('F1 Score/test', f1, epochs)
        test_writer.add_scalar('ROC AUC/test', roc_auc, epochs)
        
        print(f"Epoch: {epochs}, Phase: Train")
        print(f"  Loss: {test_loss:.4f}")
        print(f"  Accuracy: {accuracy:.4f}")
        print(f"  Balanced Accuracy: {balanced_accuracy:.4f}")
        print(f"  F1 Score: {f1:.4f}")
        print(f"  ROC AUC: {roc_auc:.4f}")

        # Save test metrics
        save_metrics(log_filename, epochs, "Test", test_loss,
                     accuracy, balanced_accuracy, f1, roc_auc)

    # Close the TensorBoard writer
    writer.close()


In [None]:
transform = transforms.ToTensor()
train_mnist = MNIST(root='./cifar', train=True, download=True, transform=transform)
test_mnist = MNIST(root='./cifar', train=False, download=True, transform=transform)
train_loader = DataLoader(train_mnist, shuffle=True, batch_size=128)
test_loader = DataLoader(test_mnist, shuffle=False, batch_size=128)
main(train_loader=train_loader, test_loader=test_loader, epochs=10)

Using device:  cuda (Tesla P100-PCIE-16GB)


train:   0%|          | 0/8 [00:00<?, ?it/s]
Epoch 1 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 1 in training:   0%|          | 1/469 [00:01<08:14,  1.06s/it][A
Epoch 1 in training:   0%|          | 2/469 [00:01<05:08,  1.51it/s][A
Epoch 1 in training:   1%|          | 3/469 [00:01<04:08,  1.87it/s][A
Epoch 1 in training:   1%|          | 4/469 [00:02<03:41,  2.10it/s][A
Epoch 1 in training:   1%|          | 5/469 [00:02<03:27,  2.24it/s][A
Epoch 1 in training:   1%|▏         | 6/469 [00:02<03:16,  2.35it/s][A
Epoch 1 in training:   1%|▏         | 7/469 [00:03<03:11,  2.41it/s][A
Epoch 1 in training:   2%|▏         | 8/469 [00:03<03:08,  2.45it/s][A
Epoch 1 in training:   2%|▏         | 9/469 [00:04<03:05,  2.48it/s][A
Epoch 1 in training:   2%|▏         | 10/469 [00:04<03:02,  2.52it/s][A
Epoch 1 in training:   2%|▏         | 11/469 [00:04<03:07,  2.45it/s][A
Epoch 1 in training:   3%|▎         | 12/469 [00:05<03:08,  2.43it/s][A
Epoch 1 in training:   3

Epoch 1/8 loss: 2.17



Epoch 2 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 2 in training:   0%|          | 1/469 [00:00<03:09,  2.47it/s][A
Epoch 2 in training:   0%|          | 2/469 [00:00<03:06,  2.50it/s][A
Epoch 2 in training:   1%|          | 3/469 [00:01<03:04,  2.52it/s][A
Epoch 2 in training:   1%|          | 4/469 [00:01<03:03,  2.54it/s][A
Epoch 2 in training:   1%|          | 5/469 [00:01<03:01,  2.55it/s][A
Epoch 2 in training:   1%|▏         | 6/469 [00:02<03:02,  2.54it/s][A
Epoch 2 in training:   1%|▏         | 7/469 [00:02<03:01,  2.55it/s][A
Epoch 2 in training:   2%|▏         | 8/469 [00:03<03:01,  2.55it/s][A
Epoch 2 in training:   2%|▏         | 9/469 [00:03<03:08,  2.44it/s][A
Epoch 2 in training:   2%|▏         | 10/469 [00:03<03:05,  2.47it/s][A
Epoch 2 in training:   2%|▏         | 11/469 [00:04<03:04,  2.48it/s][A
Epoch 2 in training:   3%|▎         | 12/469 [00:04<03:03,  2.49it/s][A
Epoch 2 in training:   3%|▎         | 13/469 [00:05<03:01,  2.51it/s

Epoch 2/8 loss: 2.12



Epoch 3 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 3 in training:   0%|          | 1/469 [00:00<03:04,  2.54it/s][A
Epoch 3 in training:   0%|          | 2/469 [00:00<03:03,  2.54it/s][A
Epoch 3 in training:   1%|          | 3/469 [00:01<03:10,  2.44it/s][A
Epoch 3 in training:   1%|          | 4/469 [00:01<03:09,  2.45it/s][A
Epoch 3 in training:   1%|          | 5/469 [00:01<03:03,  2.52it/s][A
Epoch 3 in training:   1%|▏         | 6/469 [00:02<03:00,  2.57it/s][A
Epoch 3 in training:   1%|▏         | 7/469 [00:02<02:59,  2.58it/s][A
Epoch 3 in training:   2%|▏         | 8/469 [00:03<03:00,  2.56it/s][A
Epoch 3 in training:   2%|▏         | 9/469 [00:03<02:58,  2.58it/s][A
Epoch 3 in training:   2%|▏         | 10/469 [00:03<02:56,  2.60it/s][A
Epoch 3 in training:   2%|▏         | 11/469 [00:04<02:56,  2.60it/s][A
Epoch 3 in training:   3%|▎         | 12/469 [00:04<02:54,  2.62it/s][A
Epoch 3 in training:   3%|▎         | 13/469 [00:05<02:54,  2.61it/s

Epoch 3/8 loss: 2.08



Epoch 4 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 4 in training:   0%|          | 1/469 [00:00<03:05,  2.52it/s][A
Epoch 4 in training:   0%|          | 2/469 [00:00<03:02,  2.56it/s][A
Epoch 4 in training:   1%|          | 3/469 [00:01<03:03,  2.54it/s][A
Epoch 4 in training:   1%|          | 4/469 [00:01<03:02,  2.55it/s][A
Epoch 4 in training:   1%|          | 5/469 [00:01<03:01,  2.55it/s][A
Epoch 4 in training:   1%|▏         | 6/469 [00:02<03:01,  2.56it/s][A
Epoch 4 in training:   1%|▏         | 7/469 [00:02<03:00,  2.57it/s][A
Epoch 4 in training:   2%|▏         | 8/469 [00:03<02:58,  2.58it/s][A
Epoch 4 in training:   2%|▏         | 9/469 [00:03<02:57,  2.59it/s][A
Epoch 4 in training:   2%|▏         | 10/469 [00:03<02:57,  2.59it/s][A
Epoch 4 in training:   2%|▏         | 11/469 [00:04<02:57,  2.59it/s][A
Epoch 4 in training:   3%|▎         | 12/469 [00:04<02:57,  2.57it/s][A
Epoch 4 in training:   3%|▎         | 13/469 [00:05<02:59,  2.55it/s

Epoch 4/8 loss: 2.01



Epoch 5 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 5 in training:   0%|          | 1/469 [00:00<02:57,  2.64it/s][A
Epoch 5 in training:   0%|          | 2/469 [00:00<02:59,  2.61it/s][A
Epoch 5 in training:   1%|          | 3/469 [00:01<02:58,  2.61it/s][A
Epoch 5 in training:   1%|          | 4/469 [00:01<02:57,  2.62it/s][A
Epoch 5 in training:   1%|          | 5/469 [00:01<02:57,  2.61it/s][A
Epoch 5 in training:   1%|▏         | 6/469 [00:02<02:56,  2.63it/s][A
Epoch 5 in training:   1%|▏         | 7/469 [00:02<02:56,  2.61it/s][A
Epoch 5 in training:   2%|▏         | 8/469 [00:03<02:55,  2.62it/s][A
Epoch 5 in training:   2%|▏         | 9/469 [00:03<02:54,  2.63it/s][A
Epoch 5 in training:   2%|▏         | 10/469 [00:03<02:54,  2.63it/s][A
Epoch 5 in training:   2%|▏         | 11/469 [00:04<02:54,  2.63it/s][A
Epoch 5 in training:   3%|▎         | 12/469 [00:04<02:55,  2.60it/s][A
Epoch 5 in training:   3%|▎         | 13/469 [00:04<02:55,  2.60it/s

Epoch 5/8 loss: 1.97



Epoch 6 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 6 in training:   0%|          | 1/469 [00:00<03:01,  2.58it/s][A
Epoch 6 in training:   0%|          | 2/469 [00:00<03:01,  2.57it/s][A
Epoch 6 in training:   1%|          | 3/469 [00:01<03:02,  2.55it/s][A
Epoch 6 in training:   1%|          | 4/469 [00:01<03:01,  2.56it/s][A
Epoch 6 in training:   1%|          | 5/469 [00:01<03:00,  2.57it/s][A
Epoch 6 in training:   1%|▏         | 6/469 [00:02<03:01,  2.55it/s][A
Epoch 6 in training:   1%|▏         | 7/469 [00:02<02:59,  2.57it/s][A
Epoch 6 in training:   2%|▏         | 8/469 [00:03<03:03,  2.52it/s][A
Epoch 6 in training:   2%|▏         | 9/469 [00:03<03:02,  2.53it/s][A
Epoch 6 in training:   2%|▏         | 10/469 [00:03<03:00,  2.54it/s][A
Epoch 6 in training:   2%|▏         | 11/469 [00:04<03:00,  2.54it/s][A
Epoch 6 in training:   3%|▎         | 12/469 [00:04<03:00,  2.53it/s][A
Epoch 6 in training:   3%|▎         | 13/469 [00:05<02:59,  2.54it/s

Epoch 6/8 loss: 1.95



Epoch 7 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 7 in training:   0%|          | 1/469 [00:00<03:05,  2.52it/s][A
Epoch 7 in training:   0%|          | 2/469 [00:00<03:03,  2.54it/s][A
Epoch 7 in training:   1%|          | 3/469 [00:01<03:03,  2.54it/s][A
Epoch 7 in training:   1%|          | 4/469 [00:01<03:01,  2.56it/s][A
Epoch 7 in training:   1%|          | 5/469 [00:01<03:01,  2.56it/s][A
Epoch 7 in training:   1%|▏         | 6/469 [00:02<03:02,  2.53it/s][A
Epoch 7 in training:   1%|▏         | 7/469 [00:02<03:02,  2.53it/s][A
Epoch 7 in training:   2%|▏         | 8/469 [00:03<03:01,  2.55it/s][A
Epoch 7 in training:   2%|▏         | 9/469 [00:03<03:00,  2.55it/s][A
Epoch 7 in training:   2%|▏         | 10/469 [00:03<02:58,  2.57it/s][A
Epoch 7 in training:   2%|▏         | 11/469 [00:04<02:58,  2.57it/s][A
Epoch 7 in training:   3%|▎         | 12/469 [00:04<02:57,  2.57it/s][A
Epoch 7 in training:   3%|▎         | 13/469 [00:05<02:57,  2.57it/s

Epoch 7/8 loss: 1.92



Epoch 8 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 8 in training:   0%|          | 1/469 [00:00<03:03,  2.55it/s][A
Epoch 8 in training:   0%|          | 2/469 [00:00<03:10,  2.45it/s][A
Epoch 8 in training:   1%|          | 3/469 [00:01<03:10,  2.45it/s][A
Epoch 8 in training:   1%|          | 4/469 [00:01<03:06,  2.49it/s][A
Epoch 8 in training:   1%|          | 5/469 [00:01<03:02,  2.54it/s][A
Epoch 8 in training:   1%|▏         | 6/469 [00:02<02:59,  2.58it/s][A
Epoch 8 in training:   1%|▏         | 7/469 [00:02<02:57,  2.60it/s][A
Epoch 8 in training:   2%|▏         | 8/469 [00:03<02:56,  2.62it/s][A
Epoch 8 in training:   2%|▏         | 9/469 [00:03<02:55,  2.61it/s][A
Epoch 8 in training:   2%|▏         | 10/469 [00:03<02:53,  2.64it/s][A
Epoch 8 in training:   2%|▏         | 11/469 [00:04<02:55,  2.61it/s][A
Epoch 8 in training:   3%|▎         | 12/469 [00:04<03:00,  2.53it/s][A
Epoch 8 in training:   3%|▎         | 13/469 [00:05<02:57,  2.57it/s

Epoch 8/8 loss: 1.88


Testing: 100%|██████████| 79/79 [00:29<00:00,  2.69it/s]

Test loss: 1.85
Test accuracy: 61.63%





In [10]:
path: str = "flashkan_vit_8epochs.pth"
torch.save(mnist_model.state_dict(), path)