# Install

In [1]:
!pip install torch torchvision

Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch)
  Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)
Collecting nvidia-curand-cu12==10.3.2.106 (from torch)
  Using cached nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)
Collectin

# Import

In [2]:
import torch
import torch.nn.functional as F
from torch import nn

import torch.func as fc
from torch import Tensor

In [3]:
from typing import List, Optional
from typing import Dict, KeysView, ValuesView

In [4]:
import torchvision
import matplotlib.pyplot as plt
import torchvision.transforms as tt

In [5]:
import math
import time
from tqdm import tqdm
import copy
from functools import partial

# Constants

In [6]:
# Mini-batch size handed to both the training and the test DataLoader.
BATCH_SIZE =  64
# Shuffle flag handed to both DataLoaders (training and test).
SHUFFLE = True
# Number of epochs passed as `total_epochs` to each training run.
EPOCHS = 50

# Model definitions

In [7]:
class NeuralNet(nn.Module):
    """Multi-layer perceptron: a stack of ``Linear`` + activation pairs and a final ``Linear``.

    The hidden layers are registered as attributes ``fc0``/``act0``, ``fc1``/``act1``, ...
    and the output layer as ``out``, so parameter names look like ``fc0.weight``.
    """

    def __init__(
        self,
        input_size: int,
        hidden_sizes: List[int],
        output_size: int,
        activation_function: Optional[torch.nn.Module] = None,
    ):
        """Standard fully-connected layers.

        Args:
            input_size (int): input size of the model.
            hidden_sizes (List[int]): a list of hidden sizes. The list is copied,
                never modified, so it can be reused by the caller.
            output_size (int): The number of output classes.
            activation_function (Optional[torch.nn.Module], optional): the activation
                function for the hidden layers. The same module instance is registered
                after every hidden layer. Defaults to None, in which case a fresh
                ``nn.ReLU()`` is created per hidden layer.

        """
        super().__init__()
        self.input_size = input_size
        # Build a new list instead of inserting into the caller's list: in-place
        # insertion would make a second model built from the same list start with
        # an extra `input_size -> input_size` layer.
        self.hidden_sizes = [input_size, *hidden_sizes]
        for i in range(len(self.hidden_sizes) - 1):
            setattr(self, f"fc{i}", nn.Linear(self.hidden_sizes[i], self.hidden_sizes[i + 1]))
            # Compare against None explicitly: modules that define __len__
            # (e.g. an empty nn.Sequential) are falsy and would be replaced by ReLU.
            activation = activation_function if activation_function is not None else nn.ReLU()
            setattr(self, f"act{i}", activation)
        self.out = nn.Linear(self.hidden_sizes[-1], output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every ``fc{i}``/``act{i}`` pair in order, then the output layer.

        Args:
            x (torch.Tensor): input whose last dimension is ``input_size``.

        Returns:
            torch.Tensor: output of the final linear layer (no activation applied).
        """
        for i in range(len(self.hidden_sizes) - 1):
            x = getattr(self, f"act{i}")(getattr(self, f"fc{i}")(x))
        x = self.out(x)
        return x


# Loss Definition

In [8]:
def cross_entropy(model: torch.nn.Module, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Run ``model`` on ``x`` and score its output against ``t`` with cross-entropy.

    Args:
        model (torch.nn.Module): PyTorch model; called as ``model(x)``.
        x (torch.Tensor): Input tensor fed to the model.
        t (torch.Tensor): Targets passed to ``F.cross_entropy``.

    Returns:
        torch.Tensor: The cross-entropy loss of the model output w.r.t. the targets.
    """
    return F.cross_entropy(input=model(x), target=t)


def functional_cross_entropy(
    params: ValuesView,
    buffers: Dict[str, Tensor],
    names: KeysView,
    model: torch.nn.Module,
    x: torch.Tensor,
    t: torch.Tensor,
) -> torch.Tensor:
    """Cross-entropy loss of ``model`` evaluated with externally supplied parameters.

    ``names`` and ``params`` are zipped into a name -> tensor mapping, which is
    handed together with ``buffers`` to ``torch.func.functional_call`` so the
    forward pass uses those tensors in place of the module's own state.

    Args:
        params: Parameter tensors, in the same order as ``names``.
        buffers: Buffers of the model, keyed by name.
        names: Names of the parameters.
        model: A pytorch model providing the forward computation.
        x (torch.Tensor): Input tensor for the PyTorch model.
        t (torch.Tensor): Targets.

    Returns:
        torch.Tensor: Cross-entropy loss.
    """
    named_params = dict(zip(names, params))
    logits = fc.functional_call(model, (named_params, buffers), (x,))
    return F.cross_entropy(logits, t)

# Utils

In [9]:
def exponential_lr_decay(step: int, k: float=1e-4):
    """Exponential decay factor ``e ** (-step * k)``.

    Args:
        step (int): Step counter the decay is evaluated at.
        k (float, optional): Decay rate. Defaults to 1e-4.

    Returns:
        The multiplicative factor; 1.0 at ``step == 0``.
    """
    exponent = -step * k
    return math.pow(math.e, exponent)

# Get the data

In [10]:
# Per-image preprocessing: convert to a tensor, normalize with the given
# mean/std, then flatten so each sample is a 1-D vector.
transform = [
    tt.ToTensor(),
    tt.Normalize((0.1307,), (0.3081,)),
    tt.Lambda(torch.flatten),
]

# Train and test splits share the same preprocessing steps; each gets its own Compose.
mnist_train = torchvision.datasets.MNIST(
    root="./data", train=True, download=True, transform=tt.Compose(transform)
)
mnist_test = torchvision.datasets.MNIST(
    root="./data", train=False, download=True, transform=tt.Compose(transform)
)

# Model dimensions: number of values in one raw image, and number of classes.
input_size = mnist_train.data[0].numel()
output_size = len(mnist_train.classes)

train_dataLoader = torch.utils.data.DataLoader(
    dataset=mnist_train, batch_size=BATCH_SIZE, shuffle=SHUFFLE
)
test_dataLoader = torch.utils.data.DataLoader(
    dataset=mnist_test, batch_size=BATCH_SIZE, shuffle=SHUFFLE
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 18214150.61it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 498355.19it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 4488614.66it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 10892240.58it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



# Train the model

In [11]:
def _functional_accuracy(base_model, named_params, named_buffers, data_loader, device):
    """Accuracy over ``data_loader`` using a functional call with the given parameters.

    Args:
        base_model: Module providing the forward computation.
        named_params: Mapping of parameter names to the tensors to evaluate with.
        named_buffers: Mapping of buffer names to tensors.
        data_loader: Iterable of ``(images, labels)`` batches.
        device: Device the batches are moved to.

    Returns:
        torch.Tensor: 0-d tensor, correct predictions divided by the number of
        samples actually iterated (0 when the loader is empty).
    """
    correct = torch.zeros((), dtype=torch.long, device=device)
    seen = 0
    with torch.no_grad():
        for images, labels in data_loader:
            labels = labels.to(device)
            logits = fc.functional_call(base_model, (named_params, named_buffers), (images.to(device),))
            # argmax over the raw outputs; a softmax first would not change the ranking.
            correct += (logits.argmax(dim=-1) == labels).sum()
            seen += labels.numel()
    # Divide by the samples seen instead of a global dataset length, so the
    # result is right for whatever loader was passed in.
    return correct / max(seen, 1)


def train_model_fwrd(train_loader, test_loader, total_epochs, input_size, output_size):
    """Train a ``NeuralNet`` with gradients estimated by forward-mode AD.

    For every batch a random tangent ``v`` (standard normal, one tensor per
    parameter) is sampled, ``torch.func.jvp`` returns the loss and the
    directional derivative of the loss along ``v``, and ``v * jvp`` is written
    to ``p.grad`` before a plain SGD step. ``loss.backward()`` is never called.

    Per-epoch timing, last-batch loss and test accuracy are printed, followed by
    the mean epoch time and a final test accuracy. Nothing is returned.

    Args:
        train_loader: Iterable of ``(images, labels)`` training batches; must support ``len``.
        test_loader: Iterable of ``(images, labels)`` evaluation batches.
        total_epochs: Number of passes over ``train_loader``.
        input_size: Input size passed to ``NeuralNet``.
        output_size: Output size passed to ``NeuralNet``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Reverse-mode autograd is never used here, so the whole run is wrapped in
    # no_grad. Per the PyTorch docs, no_grad does not apply to forward-mode AD,
    # so the jvp below still works.
    with torch.no_grad():
        model = NeuralNet(input_size=input_size, output_size=output_size, hidden_sizes=[1024, 1024])
        model.to(device)
        model.float()
        model.train()

        optimizer = torch.optim.SGD(
            params=model.parameters(), lr=2e-4, nesterov=False, momentum=0.0, weight_decay=0.0
        )
        optimizer.zero_grad(set_to_none=True)

        # scheduler.step() is called once per batch, so the decay is indexed by batch count.
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=exponential_lr_decay)

        named_buffers = dict(model.named_buffers())
        named_params = dict(model.named_parameters())
        names = named_params.keys()
        # The Parameter objects are updated in place by the optimizer, so this
        # tuple can be built once and reused for every batch.
        params = tuple(named_params.values())

        # Copy moved to the "meta" device: used only as the module passed to
        # functional_call, which is given the real tensors from `model` explicitly.
        base_model = copy.deepcopy(model)
        base_model.to("meta")

        # Train
        steps = 0
        t_total = 0.0
        # Placeholder so the epoch summary can still be printed if the loader is empty.
        loss = torch.tensor(float("nan"))

        for epoch in range(total_epochs):
            t0 = time.perf_counter()
            for images, labels in tqdm(train_loader):
                steps += 1

                # Sample perturbation (tangent) vectors for every parameter of the model
                v_params = tuple(torch.randn_like(p) for p in params)
                f = partial(
                    functional_cross_entropy,
                    model=base_model,
                    names=names,
                    buffers=named_buffers,
                    x=images.to(device),
                    t=labels.to(device),
                )

                # Forward AD: loss value and its directional derivative along v_params
                loss, jvp = fc.jvp(f, (params,), (v_params,))

                # Setting gradients: scale each tangent by the directional derivative
                for v, p in zip(v_params, params):
                    p.grad = v * jvp

                optimizer.step()
                scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            t1 = time.perf_counter()
            t_total += t1 - t0
            print("Time/batch_time", t1 - t0, steps)
            print("Time/sps", steps / t_total, steps)

            epoch_acc = _functional_accuracy(base_model, named_params, named_buffers, test_loader, device)
            print(f"Epoch [{epoch+1}/{total_epochs}], Loss: {loss.item():.4f}, Time (s): {t1 - t0:.4f}, Test accuracy: {epoch_acc.item():.4f}")
        # max(..., 1) avoids a ZeroDivisionError when total_epochs is 0.
        print(f"Mean time: {t_total / max(total_epochs, 1):.4f}")

        # Test
        test_acc = _functional_accuracy(base_model, named_params, named_buffers, test_loader, device)
        print("Test/accuracy", test_acc, steps)
        print(f"Test accuracy: {test_acc.item():.4f}")

In [12]:
def _module_accuracy(model, data_loader, device):
    """Accuracy of ``model`` over ``data_loader``.

    Runs with autograd recording disabled and the model in eval mode; the
    model's previous train/eval mode is restored afterwards.

    Args:
        model: Module called as ``model(images)``.
        data_loader: Iterable of ``(images, labels)`` batches.
        device: Device the batches are moved to.

    Returns:
        torch.Tensor: 0-d tensor, correct predictions divided by the number of
        samples actually iterated (0 when the loader is empty).
    """
    correct = torch.zeros((), dtype=torch.long, device=device)
    seen = 0
    was_training = model.training
    model.eval()
    try:
        # No graph is needed for evaluation; without no_grad every forward
        # pass would record one.
        with torch.no_grad():
            for images, labels in data_loader:
                labels = labels.to(device)
                logits = model(images.to(device))
                # argmax over the raw outputs; a softmax first would not change the ranking.
                correct += (logits.argmax(dim=-1) == labels).sum()
                seen += labels.numel()
    finally:
        model.train(was_training)
    # Divide by the samples seen instead of a global dataset length, so the
    # result is right for whatever loader was passed in.
    return correct / max(seen, 1)


def train_model_bkrd(train_loader, test_loader, total_epochs, input_size, output_size):
    """Train a ``NeuralNet`` with ordinary backpropagation and SGD.

    Each batch: compute the cross-entropy loss, call ``loss.backward()``, take
    an optimizer step, advance the LR scheduler and clear the gradients.

    Per-epoch timing, last-batch loss and test accuracy are printed, followed by
    the mean epoch time and a final test accuracy. Nothing is returned.

    Args:
        train_loader: Iterable of ``(images, labels)`` training batches; must support ``len``.
        test_loader: Iterable of ``(images, labels)`` evaluation batches.
        total_epochs: Number of passes over ``train_loader``.
        input_size: Input size passed to ``NeuralNet``.
        output_size: Output size passed to ``NeuralNet``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = NeuralNet(input_size=input_size, output_size=output_size, hidden_sizes=[1024, 1024])
    model.to(device)
    model.float()
    model.train()

    optimizer = torch.optim.SGD(
        params=model.parameters(), lr=2e-4, nesterov=False, momentum=0.0, weight_decay=0.0
    )
    optimizer.zero_grad(set_to_none=True)

    # scheduler.step() is called once per batch, so the decay is indexed by batch count.
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=exponential_lr_decay)

    steps = 0
    t_total = 0.0
    # Placeholder so the epoch summary can still be printed if the loader is empty.
    loss = torch.tensor(float("nan"))
    for epoch in range(total_epochs):
        t0 = time.perf_counter()
        for images, labels in tqdm(train_loader):
            steps += 1
            loss = cross_entropy(model, images.to(device), labels.to(device))
            loss.backward()

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)

        t1 = time.perf_counter()
        t_total += t1 - t0
        print("Time/batch_time", t1 - t0, steps)
        print("Time/sps", steps / t_total, steps)

        # Test
        epoch_acc = _module_accuracy(model, test_loader, device)
        print(f"Epoch [{epoch+1}/{total_epochs}], Loss: {loss.item():.4f}, Time (s): {t1 - t0:.4f}, Test accuracy: {epoch_acc.item():.4f}")
    # max(..., 1) avoids a ZeroDivisionError when total_epochs is 0.
    print("Mean time:", t_total / max(total_epochs, 1))

    # Test
    test_acc = _module_accuracy(model, test_loader, device)
    print("Test/accuracy", test_acc, steps)
    print(f"Test accuracy: {test_acc.item():.4f}")



# Train Backward

In [13]:
# Baseline run: train with backpropagation (loss.backward()) for EPOCHS epochs.
train_model_bkrd(train_dataLoader, test_dataLoader, total_epochs=EPOCHS, input_size=input_size, output_size=output_size)

100%|██████████| 938/938 [00:19<00:00, 47.35it/s]


Time/batch_time 19.816051845999993 938
Time/sps 47.33536262872373 938
Epoch [1/50], Loss: 2.1495, Time (s): 19.8161, Test accuracy: 0.5586


100%|██████████| 938/938 [00:15<00:00, 59.29it/s]


Time/batch_time 15.826198306999999 1876
Time/sps 52.63416288104633 1876
Epoch [2/50], Loss: 2.0692, Time (s): 15.8262, Test accuracy: 0.6593


100%|██████████| 938/938 [00:14<00:00, 62.81it/s]


Time/batch_time 14.939696803000004 2814
Time/sps 55.63249675714994 2814
Epoch [3/50], Loss: 1.9742, Time (s): 14.9397, Test accuracy: 0.7023


100%|██████████| 938/938 [00:15<00:00, 61.00it/s]


Time/batch_time 15.383464024000006 3752
Time/sps 56.878293400424134 3752
Epoch [4/50], Loss: 1.8465, Time (s): 15.3835, Test accuracy: 0.7223


100%|██████████| 938/938 [00:15<00:00, 62.26it/s]


Time/batch_time 15.07057319399999 4690
Time/sps 57.87552342092445 4690
Epoch [5/50], Loss: 1.7991, Time (s): 15.0706, Test accuracy: 0.7405


100%|██████████| 938/938 [00:14<00:00, 62.54it/s]


Time/batch_time 15.003587546000006 5628
Time/sps 58.6008444145111 5628
Epoch [6/50], Loss: 1.4515, Time (s): 15.0036, Test accuracy: 0.7534


100%|██████████| 938/938 [00:15<00:00, 59.93it/s]


Time/batch_time 15.659526577999998 6566
Time/sps 58.782927526260664 6566
Epoch [7/50], Loss: 1.4336, Time (s): 15.6595, Test accuracy: 0.7664


100%|██████████| 938/938 [00:14<00:00, 62.86it/s]


Time/batch_time 14.927048575000015 7504
Time/sps 59.26106246860812 7504
Epoch [8/50], Loss: 1.3581, Time (s): 14.9270, Test accuracy: 0.7782


100%|██████████| 938/938 [00:15<00:00, 61.86it/s]


Time/batch_time 15.168439903999968 8442
Time/sps 59.53682853405902 8442
Epoch [9/50], Loss: 1.1648, Time (s): 15.1684, Test accuracy: 0.7887


100%|██████████| 938/938 [00:15<00:00, 59.45it/s]


Time/batch_time 15.782460963000005 9380
Time/sps 59.52643569942289 9380
Epoch [10/50], Loss: 1.1534, Time (s): 15.7825, Test accuracy: 0.7965


100%|██████████| 938/938 [00:15<00:00, 60.86it/s]


Time/batch_time 15.419041567000022 10318
Time/sps 59.64296673602609 10318
Epoch [11/50], Loss: 1.1786, Time (s): 15.4190, Test accuracy: 0.8035


100%|██████████| 938/938 [00:15<00:00, 62.44it/s]


Time/batch_time 15.026981704000036 11256
Time/sps 59.864993904612284 11256
Epoch [12/50], Loss: 1.0818, Time (s): 15.0270, Test accuracy: 0.8103


100%|██████████| 938/938 [00:15<00:00, 60.28it/s]


Time/batch_time 15.56903681 12194
Time/sps 59.89426668110864 12194
Epoch [13/50], Loss: 1.1251, Time (s): 15.5690, Test accuracy: 0.8158


100%|██████████| 938/938 [00:14<00:00, 63.40it/s]


Time/batch_time 14.799502528000005 13132
Time/sps 60.130514991003764 13132
Epoch [14/50], Loss: 1.1423, Time (s): 14.7995, Test accuracy: 0.8201


100%|██████████| 938/938 [00:14<00:00, 63.13it/s]


Time/batch_time 14.863818468999966 14070
Time/sps 60.320139476703304 14070
Epoch [15/50], Loss: 1.1092, Time (s): 14.8638, Test accuracy: 0.8251


100%|██████████| 938/938 [00:15<00:00, 61.65it/s]


Time/batch_time 15.221953571999961 15008
Time/sps 60.39986358373679 15008
Epoch [16/50], Loss: 0.9791, Time (s): 15.2220, Test accuracy: 0.8279


100%|██████████| 938/938 [00:14<00:00, 63.04it/s]


Time/batch_time 14.884376439999983 15946
Time/sps 60.54789454186909 15946
Epoch [17/50], Loss: 0.9531, Time (s): 14.8844, Test accuracy: 0.8299


100%|██████████| 938/938 [00:15<00:00, 60.81it/s]


Time/batch_time 15.432254613999987 16884
Time/sps 60.56084128718714 16884
Epoch [18/50], Loss: 0.9752, Time (s): 15.4323, Test accuracy: 0.8328


100%|██████████| 938/938 [00:14<00:00, 63.22it/s]


Time/batch_time 14.841948901000023 17822
Time/sps 60.69420059338815 17822
Epoch [19/50], Loss: 0.8040, Time (s): 14.8419, Test accuracy: 0.8337


100%|██████████| 938/938 [00:14<00:00, 64.22it/s]


Time/batch_time 14.611573433999979 18760
Time/sps 60.86017833878192 18760
Epoch [20/50], Loss: 0.7801, Time (s): 14.6116, Test accuracy: 0.8349


100%|██████████| 938/938 [00:14<00:00, 63.82it/s]


Time/batch_time 14.703109631000018 19698
Time/sps 60.99384001847256 19698
Epoch [21/50], Loss: 1.0224, Time (s): 14.7031, Test accuracy: 0.8371


100%|██████████| 938/938 [00:14<00:00, 63.78it/s]


Time/batch_time 14.710586070999966 20636
Time/sps 61.11450790334862 20636
Epoch [22/50], Loss: 0.8386, Time (s): 14.7106, Test accuracy: 0.8389


100%|██████████| 938/938 [00:15<00:00, 60.65it/s]


Time/batch_time 15.470790757000032 21574
Time/sps 61.093298373999616 21574
Epoch [23/50], Loss: 0.8452, Time (s): 15.4708, Test accuracy: 0.8406


100%|██████████| 938/938 [00:14<00:00, 64.07it/s]


Time/batch_time 14.647891679000054 22512
Time/sps 61.21052060793203 22512
Epoch [24/50], Loss: 0.9140, Time (s): 14.6479, Test accuracy: 0.8410


100%|██████████| 938/938 [00:14<00:00, 63.04it/s]


Time/batch_time 14.886846135000042 23450
Time/sps 61.28047285009393 23450
Epoch [25/50], Loss: 0.8367, Time (s): 14.8868, Test accuracy: 0.8426


100%|██████████| 938/938 [00:14<00:00, 63.48it/s]


Time/batch_time 14.781101990000025 24388
Time/sps 61.361507581648056 24388
Epoch [26/50], Loss: 0.8689, Time (s): 14.7811, Test accuracy: 0.8435


100%|██████████| 938/938 [00:14<00:00, 63.78it/s]


Time/batch_time 14.712858517000086 25326
Time/sps 61.446903444974495 25326
Epoch [27/50], Loss: 0.7359, Time (s): 14.7129, Test accuracy: 0.8446


100%|██████████| 938/938 [00:15<00:00, 61.06it/s]


Time/batch_time 15.368132991000039 26264
Time/sps 61.43211103043921 26264
Epoch [28/50], Loss: 0.6022, Time (s): 15.3681, Test accuracy: 0.8447


100%|██████████| 938/938 [00:14<00:00, 63.12it/s]


Time/batch_time 14.867387490999931 27202
Time/sps 61.48786421871742 27202
Epoch [29/50], Loss: 0.7287, Time (s): 14.8674, Test accuracy: 0.8453


100%|██████████| 938/938 [00:15<00:00, 61.78it/s]


Time/batch_time 15.18996538500005 28140
Time/sps 61.49660895547719 28140
Epoch [30/50], Loss: 0.9131, Time (s): 15.1900, Test accuracy: 0.8456


100%|██████████| 938/938 [00:14<00:00, 63.74it/s]


Time/batch_time 14.720797917000027 29078
Time/sps 61.56588772264509 29078
Epoch [31/50], Loss: 0.7878, Time (s): 14.7208, Test accuracy: 0.8461


100%|██████████| 938/938 [00:14<00:00, 64.14it/s]


Time/batch_time 14.62962331999995 30016
Time/sps 61.64251833767066 30016
Epoch [32/50], Loss: 0.7669, Time (s): 14.6296, Test accuracy: 0.8465


100%|██████████| 938/938 [00:15<00:00, 61.73it/s]


Time/batch_time 15.20398615900001 30954
Time/sps 61.644087627958 30954
Epoch [33/50], Loss: 0.9524, Time (s): 15.2040, Test accuracy: 0.8468


100%|██████████| 938/938 [00:14<00:00, 63.32it/s]


Time/batch_time 14.818071816000042 31892
Time/sps 61.691583652648674 31892
Epoch [34/50], Loss: 0.6938, Time (s): 14.8181, Test accuracy: 0.8470


100%|██████████| 938/938 [00:15<00:00, 62.15it/s]


Time/batch_time 15.097859258000085 32830
Time/sps 61.70396796155791 32830
Epoch [35/50], Loss: 0.7275, Time (s): 15.0979, Test accuracy: 0.8473


100%|██████████| 938/938 [00:14<00:00, 63.47it/s]


Time/batch_time 14.783429130999934 33768
Time/sps 61.751155006698 33768
Epoch [36/50], Loss: 0.7051, Time (s): 14.7834, Test accuracy: 0.8479


100%|██████████| 938/938 [00:14<00:00, 63.16it/s]


Time/batch_time 14.855866116000016 34706
Time/sps 61.78788860384708 34706
Epoch [37/50], Loss: 0.5711, Time (s): 14.8559, Test accuracy: 0.8483


100%|██████████| 938/938 [00:15<00:00, 62.32it/s]


Time/batch_time 15.057213058999992 35644
Time/sps 61.80114661225062 35644
Epoch [38/50], Loss: 0.8830, Time (s): 15.0572, Test accuracy: 0.8486


100%|██████████| 938/938 [00:14<00:00, 63.68it/s]


Time/batch_time 14.735328048000042 36582
Time/sps 61.84736870518195 36582
Epoch [39/50], Loss: 0.9598, Time (s): 14.7353, Test accuracy: 0.8492


100%|██████████| 938/938 [00:15<00:00, 60.59it/s]


Time/batch_time 15.487436805000016 37520
Time/sps 61.81465370437571 37520
Epoch [40/50], Loss: 0.6544, Time (s): 15.4874, Test accuracy: 0.8491


100%|██████████| 938/938 [00:14<00:00, 63.60it/s]


Time/batch_time 14.755165173000023 38458
Time/sps 61.85633503469254 38458
Epoch [41/50], Loss: 0.8543, Time (s): 14.7552, Test accuracy: 0.8494


100%|██████████| 938/938 [00:14<00:00, 63.58it/s]


Time/batch_time 14.75957232099995 39396
Time/sps 61.895655256061524 39396
Epoch [42/50], Loss: 0.7378, Time (s): 14.7596, Test accuracy: 0.8495


100%|██████████| 938/938 [00:14<00:00, 63.25it/s]


Time/batch_time 14.84039795800004 40334
Time/sps 61.92550773585567 40334
Epoch [43/50], Loss: 0.6202, Time (s): 14.8404, Test accuracy: 0.8496


100%|██████████| 938/938 [00:14<00:00, 63.14it/s]


Time/batch_time 14.859938518000035 41272
Time/sps 61.95221293786642 41272
Epoch [44/50], Loss: 0.8480, Time (s): 14.8599, Test accuracy: 0.8498


100%|██████████| 938/938 [00:15<00:00, 60.42it/s]


Time/batch_time 15.530941273999929 42210
Time/sps 61.91674952778998 42210
Epoch [45/50], Loss: 0.8032, Time (s): 15.5309, Test accuracy: 0.8501


100%|██████████| 938/938 [00:15<00:00, 62.41it/s]


Time/batch_time 15.034835154000007 43148
Time/sps 61.92692794143031 43148
Epoch [46/50], Loss: 0.7123, Time (s): 15.0348, Test accuracy: 0.8505


100%|██████████| 938/938 [00:15<00:00, 61.89it/s]


Time/batch_time 15.163935704999972 44086
Time/sps 61.925444694975795 44086
Epoch [47/50], Loss: 0.9023, Time (s): 15.1639, Test accuracy: 0.8505


100%|██████████| 938/938 [00:14<00:00, 62.91it/s]


Time/batch_time 14.915769769999997 45024
Time/sps 61.945166223147815 45024
Epoch [48/50], Loss: 0.6883, Time (s): 14.9158, Test accuracy: 0.8505


100%|██████████| 938/938 [00:14<00:00, 62.62it/s]


Time/batch_time 14.98428550799997 45962
Time/sps 61.95837149499787 45962
Epoch [49/50], Loss: 0.8356, Time (s): 14.9843, Test accuracy: 0.8506


100%|██████████| 938/938 [00:15<00:00, 60.56it/s]


Time/batch_time 15.49578838299999 46900
Time/sps 61.929197675644794 46900
Epoch [50/50], Loss: 0.9392, Time (s): 15.4958, Test accuracy: 0.8506
Mean time: 15.146328956380005
Test/accuracy tensor(0.8506, device='cuda:0') 46900
Test accuracy: 0.8506


# Train Forward

In [14]:
# Forward-mode run: same data and epoch count, gradients come from torch.func.jvp.
train_model_fwrd(train_dataLoader, test_dataLoader, total_epochs=EPOCHS, input_size=input_size, output_size=output_size)

100%|██████████| 938/938 [00:16<00:00, 58.11it/s]


Time/batch_time 16.14793789600003 938
Time/sps 58.08791227964469 938
Epoch [1/50], Loss: 2.1591, Time (s): 16.1479, Test accuracy: 0.5174


100%|██████████| 938/938 [00:15<00:00, 59.91it/s]


Time/batch_time 15.662818399999878 1876
Time/sps 58.973762916661634 1876
Epoch [2/50], Loss: 2.0900, Time (s): 15.6628, Test accuracy: 0.6135


100%|██████████| 938/938 [00:15<00:00, 59.92it/s]


Time/batch_time 15.660543994999898 2814
Time/sps 59.27792124399661 2814
Epoch [3/50], Loss: 1.9513, Time (s): 15.6605, Test accuracy: 0.6513


100%|██████████| 938/938 [00:15<00:00, 60.35it/s]


Time/batch_time 15.550012401999993 3752
Time/sps 59.53541492030457 3752
Epoch [4/50], Loss: 1.6220, Time (s): 15.5500, Test accuracy: 0.7258


100%|██████████| 938/938 [00:15<00:00, 59.66it/s]


Time/batch_time 15.728897003999919 4690
Time/sps 59.555396970310326 4690
Epoch [5/50], Loss: 1.4177, Time (s): 15.7289, Test accuracy: 0.7378


100%|██████████| 938/938 [00:15<00:00, 60.53it/s]


Time/batch_time 15.50084240000001 5628
Time/sps 59.71286128676706 5628
Epoch [6/50], Loss: 1.1113, Time (s): 15.5008, Test accuracy: 0.7796


100%|██████████| 938/938 [00:15<00:00, 59.66it/s]


Time/batch_time 15.727981118999878 6566
Time/sps 59.70228877266387 6566
Epoch [7/50], Loss: 1.2264, Time (s): 15.7280, Test accuracy: 0.7796


100%|██████████| 938/938 [00:15<00:00, 60.54it/s]


Time/batch_time 15.499971401000039 7504
Time/sps 59.802833333787646 7504
Epoch [8/50], Loss: 0.8058, Time (s): 15.5000, Test accuracy: 0.8098


100%|██████████| 938/938 [00:15<00:00, 61.11it/s]


Time/batch_time 15.35651006099988 8442
Time/sps 59.94226683021991 8442
Epoch [9/50], Loss: 0.9158, Time (s): 15.3565, Test accuracy: 0.8167


100%|██████████| 938/938 [00:15<00:00, 59.23it/s]


Time/batch_time 15.841818315999944 9380
Time/sps 59.86826441805236 9380
Epoch [10/50], Loss: 0.7193, Time (s): 15.8418, Test accuracy: 0.8264


100%|██████████| 938/938 [00:15<00:00, 61.76it/s]


Time/batch_time 15.193136820999825 10318
Time/sps 60.03358233154454 10318
Epoch [11/50], Loss: 0.6991, Time (s): 15.1931, Test accuracy: 0.8357


100%|██████████| 938/938 [00:15<00:00, 58.80it/s]


Time/batch_time 15.958996814000102 11256
Time/sps 59.92669947912295 11256
Epoch [12/50], Loss: 0.7036, Time (s): 15.9590, Test accuracy: 0.8408


100%|██████████| 938/938 [00:15<00:00, 61.57it/s]


Time/batch_time 15.241125768000074 12194
Time/sps 60.048084048334005 12194
Epoch [13/50], Loss: 0.5227, Time (s): 15.2411, Test accuracy: 0.8400


100%|██████████| 938/938 [00:16<00:00, 58.58it/s]


Time/batch_time 16.018500613000015 13132
Time/sps 59.939086056651114 13132
Epoch [14/50], Loss: 0.8217, Time (s): 16.0185, Test accuracy: 0.8464


100%|██████████| 938/938 [00:15<00:00, 60.18it/s]


Time/batch_time 15.590894876999982 14070
Time/sps 59.95398298202927 14070
Epoch [15/50], Loss: 0.7074, Time (s): 15.5909, Test accuracy: 0.8538


100%|██████████| 938/938 [00:15<00:00, 59.17it/s]


Time/batch_time 15.859245593999958 15008
Time/sps 59.90279363227232 15008
Epoch [16/50], Loss: 0.5463, Time (s): 15.8592, Test accuracy: 0.8572


100%|██████████| 938/938 [00:15<00:00, 60.23it/s]


Time/batch_time 15.578625783999996 15946
Time/sps 59.92081870807858 15946
Epoch [17/50], Loss: 0.4561, Time (s): 15.5786, Test accuracy: 0.8581


100%|██████████| 938/938 [00:15<00:00, 59.07it/s]


Time/batch_time 15.885785878999968 16884
Time/sps 59.87156652311544 16884
Epoch [18/50], Loss: 0.4084, Time (s): 15.8858, Test accuracy: 0.8599


100%|██████████| 938/938 [00:15<00:00, 60.15it/s]


Time/batch_time 15.598283663999837 17822
Time/sps 59.885364558567886 17822
Epoch [19/50], Loss: 0.5078, Time (s): 15.5983, Test accuracy: 0.8612


100%|██████████| 938/938 [00:15<00:00, 60.47it/s]


Time/batch_time 15.517198839999992 18760
Time/sps 59.91329926381734 18760
Epoch [20/50], Loss: 0.4348, Time (s): 15.5172, Test accuracy: 0.8618


100%|██████████| 938/938 [00:15<00:00, 61.49it/s]


Time/batch_time 15.25876905799987 19698
Time/sps 59.985767000742655 19698
Epoch [21/50], Loss: 0.5794, Time (s): 15.2588, Test accuracy: 0.8631


100%|██████████| 938/938 [00:15<00:00, 60.69it/s]


Time/batch_time 15.460836036999808 20636
Time/sps 60.0165078418443 20636
Epoch [22/50], Loss: 0.6069, Time (s): 15.4608, Test accuracy: 0.8638


100%|██████████| 938/938 [00:15<00:00, 60.46it/s]


Time/batch_time 15.521489311000096 21574
Time/sps 60.034468691858166 21574
Epoch [23/50], Loss: 0.5777, Time (s): 15.5215, Test accuracy: 0.8665


100%|██████████| 938/938 [00:15<00:00, 60.50it/s]


Time/batch_time 15.508754401000033 22512
Time/sps 60.05298227900287 22512
Epoch [24/50], Loss: 0.5721, Time (s): 15.5088, Test accuracy: 0.8685


100%|██████████| 938/938 [00:15<00:00, 59.25it/s]


Time/batch_time 15.836807889000056 23450
Time/sps 60.01958747391699 23450
Epoch [25/50], Loss: 0.4586, Time (s): 15.8368, Test accuracy: 0.8701


100%|██████████| 938/938 [00:15<00:00, 61.27it/s]


Time/batch_time 15.313054764000071 24388
Time/sps 60.066178341820496 24388
Epoch [26/50], Loss: 0.4279, Time (s): 15.3131, Test accuracy: 0.8705


100%|██████████| 938/938 [00:16<00:00, 57.30it/s]


Time/batch_time 16.37750704599989 25326
Time/sps 59.95790501172219 25326
Epoch [27/50], Loss: 0.5242, Time (s): 16.3775, Test accuracy: 0.8703


100%|██████████| 938/938 [00:15<00:00, 59.87it/s]


Time/batch_time 15.67384512700005 26264
Time/sps 59.95386246938911 26264
Epoch [28/50], Loss: 0.5009, Time (s): 15.6738, Test accuracy: 0.8718


100%|██████████| 938/938 [00:16<00:00, 58.56it/s]


Time/batch_time 16.024889784999914 27202
Time/sps 59.90375393660433 27202
Epoch [29/50], Loss: 0.3733, Time (s): 16.0249, Test accuracy: 0.8727


100%|██████████| 938/938 [00:15<00:00, 60.32it/s]


Time/batch_time 15.556044300000167 28140
Time/sps 59.916815866395496 28140
Epoch [30/50], Loss: 0.2480, Time (s): 15.5560, Test accuracy: 0.8728


100%|██████████| 938/938 [00:16<00:00, 57.56it/s]


Time/batch_time 16.30105467699991 29078
Time/sps 59.837163396256514 29078
Epoch [31/50], Loss: 0.5434, Time (s): 16.3011, Test accuracy: 0.8727


100%|██████████| 938/938 [00:15<00:00, 61.30it/s]


Time/batch_time 15.306499741999914 30016
Time/sps 59.881257343762165 30016
Epoch [32/50], Loss: 0.5765, Time (s): 15.3065, Test accuracy: 0.8743


100%|██████████| 938/938 [00:15<00:00, 60.19it/s]


Time/batch_time 15.5936648039999 30954
Time/sps 59.88944487543077 30954
Epoch [33/50], Loss: 0.5441, Time (s): 15.5937, Test accuracy: 0.8744


100%|██████████| 938/938 [00:15<00:00, 60.95it/s]


Time/batch_time 15.394564449999962 31892
Time/sps 59.91955887086834 31892
Epoch [34/50], Loss: 0.4500, Time (s): 15.3946, Test accuracy: 0.8753


100%|██████████| 938/938 [00:15<00:00, 60.08it/s]


Time/batch_time 15.619810701000006 32830
Time/sps 59.92333320698898 32830
Epoch [35/50], Loss: 0.4267, Time (s): 15.6198, Test accuracy: 0.8748


100%|██████████| 938/938 [00:15<00:00, 61.37it/s]


Time/batch_time 15.290164972999946 33768
Time/sps 59.96197670070065 33768
Epoch [36/50], Loss: 0.2884, Time (s): 15.2902, Test accuracy: 0.8750


100%|██████████| 938/938 [00:15<00:00, 60.70it/s]


Time/batch_time 15.459952690000136 34706
Time/sps 59.98097144177595 34706
Epoch [37/50], Loss: 0.4995, Time (s): 15.4600, Test accuracy: 0.8756


100%|██████████| 938/938 [00:15<00:00, 59.87it/s]


Time/batch_time 15.672705994000125 35644
Time/sps 59.97749816740253 35644
Epoch [38/50], Loss: 0.4152, Time (s): 15.6727, Test accuracy: 0.8758


100%|██████████| 938/938 [00:15<00:00, 60.54it/s]


Time/batch_time 15.499533290000045 36582
Time/sps 59.99123532804329 36582
Epoch [39/50], Loss: 0.6616, Time (s): 15.4995, Test accuracy: 0.8751


100%|██████████| 938/938 [00:15<00:00, 59.96it/s]


Time/batch_time 15.64867801100013 37520
Time/sps 59.98998256527865 37520
Epoch [40/50], Loss: 0.3264, Time (s): 15.6487, Test accuracy: 0.8756


100%|██████████| 938/938 [00:15<00:00, 61.36it/s]


Time/batch_time 15.292053872999986 38458
Time/sps 60.02218015386508 38458
Epoch [41/50], Loss: 0.3217, Time (s): 15.2921, Test accuracy: 0.8753


100%|██████████| 938/938 [00:16<00:00, 58.24it/s]


Time/batch_time 16.11041292499999 39396
Time/sps 59.97805662733758 39396
Epoch [42/50], Loss: 0.3881, Time (s): 16.1104, Test accuracy: 0.8755


100%|██████████| 938/938 [00:15<00:00, 61.18it/s]


Time/batch_time 15.336755279999807 40334
Time/sps 60.005030510649554 40334
Epoch [43/50], Loss: 0.4484, Time (s): 15.3368, Test accuracy: 0.8752


100%|██████████| 938/938 [00:16<00:00, 58.42it/s]


Time/batch_time 16.06204036700001 41272
Time/sps 59.96753885202166 41272
Epoch [44/50], Loss: 0.6902, Time (s): 16.0620, Test accuracy: 0.8753


100%|██████████| 938/938 [00:15<00:00, 61.18it/s]


Time/batch_time 15.337536640000053 42210
Time/sps 59.993471603118024 42210
Epoch [45/50], Loss: 0.6319, Time (s): 15.3375, Test accuracy: 0.8750


100%|██████████| 938/938 [00:15<00:00, 59.26it/s]


Time/batch_time 15.839046968000048 43148
Time/sps 59.976458607455456 43148
Epoch [46/50], Loss: 0.6715, Time (s): 15.8390, Test accuracy: 0.8753


100%|██████████| 938/938 [00:15<00:00, 60.46it/s]


Time/batch_time 15.521595580000167 44086
Time/sps 59.98607802148124 44086
Epoch [47/50], Loss: 0.4450, Time (s): 15.5216, Test accuracy: 0.8759


100%|██████████| 938/938 [00:15<00:00, 60.03it/s]


Time/batch_time 15.633139947000018 45024
Time/sps 59.98638345244146 45024
Epoch [48/50], Loss: 0.5864, Time (s): 15.6331, Test accuracy: 0.8760


100%|██████████| 938/938 [00:15<00:00, 61.03it/s]


Time/batch_time 15.374486561999902 45962
Time/sps 60.006933436249824 45962
Epoch [49/50], Loss: 0.4578, Time (s): 15.3745, Test accuracy: 0.8756


100%|██████████| 938/938 [00:15<00:00, 61.28it/s]


Time/batch_time 15.312666157999956 46900
Time/sps 60.03142454372058 46900
Epoch [50/50], Loss: 0.4153, Time (s): 15.3127, Test accuracy: 0.8760
Mean time: 15.6251
Test/accuracy tensor(0.8760, device='cuda:0') 46900
Test accuracy: 0.8760
