In [1]:

import torch
from torch import nn, optim



In [2]:

try:
  from utils import save_experiment, save_checkpoint
  from data import prepare_data
  from vit import ViTForClassfication
except:
  print("[INFO] Couldn't find going_modular or helper_functions scripts... downloading them from GitHub.")
  !git clone https://github.com/tintn/vision-transformer-from-scratch/
  !mv vision-transformer-from-scratch/utils.py . # get the utils.py script
  !mv vision-transformer-from-scratch/data.py . # get the data.py script
  !mv vision-transformer-from-scratch/vit.py .

[INFO] Couldn't find going_modular or helper_functions scripts... downloading them from GitHub.
Cloning into 'vision-transformer-from-scratch'...
remote: Enumerating objects: 83, done.[K
remote: Counting objects: 100% (83/83), done.[K
remote: Compressing objects: 100% (66/66), done.[K
remote: Total 83 (delta 43), reused 43 (delta 17), pack-reused 0 (from 0)[K
Receiving objects: 100% (83/83), 1.34 MiB | 1.45 MiB/s, done.
Resolving deltas: 100% (43/43), done.
mv: missing destination file operand after 'vision-transformer-from-scratch/vit.py'
Try 'mv --help' for more information.


In [4]:
from utils import save_experiment, save_checkpoint
from data import prepare_data
from vit import ViTForClassfication

In [5]:
# !cd vision-transformer-from-scratch/
# !ls

In [6]:

# !mv vision-transformer-from-scratch/utils.py . # get the utils.py script
# !mv vision-transformer-from-scratch/data.py . # get the data.py script
# !mv vision-transformer-from-scratch/vit.py . # get the vit.py script
# !rm -rf vision-transformer-from-scratch

In [7]:


from tqdm import tqdm
# Model configuration passed to ViTForClassfication and stored with the experiment.
config = dict(
    patch_size=4,  # Input image size: 32x32 -> 8x8 patches
    hidden_size=48,
    num_hidden_layers=4,
    num_attention_heads=4,
    intermediate_size=4 * 48,  # 4 * hidden_size
    hidden_dropout_prob=0.0,
    attention_probs_dropout_prob=0.0,
    initializer_range=0.02,
    image_size=32,
    num_classes=10,  # num_classes of CIFAR10
    num_channels=3,
    qkv_bias=True,
    use_faster_attention=True,
)
# Soft sanity checks (not hard constraints) that catch an inconsistent configuration early:
# hidden size splits evenly across heads, MLP width is 4x hidden size, patches tile the image.
assert not config["hidden_size"] % config["num_attention_heads"]
assert 4 * config["hidden_size"] == config["intermediate_size"]
assert not config["image_size"] % config["patch_size"]


class Trainer:
    """
    Minimal train/evaluate loop for a classifier.

    The model's forward pass must return a tuple whose first element is the logits
    (the loops below index ``[0]`` / unpack two values).

    Args:
        model: ``nn.Module`` to train; moved to ``device`` on construction.
        optimizer: optimizer already bound to ``model``'s parameters.
        loss_fn: callable ``loss_fn(logits, labels)`` returning a scalar tensor.
        exp_name: experiment name forwarded to ``save_checkpoint`` / ``save_experiment``.
        device: device every batch is moved to.
        config: optional model config stored with the experiment. When omitted, ``train``
            falls back to the notebook-level ``config`` variable, as before.
    """

    def __init__(self, model, optimizer, loss_fn, exp_name, device, config=None):
        self.model = model.to(device)
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.exp_name = exp_name
        self.device = device
        self.config = config

    def train(self, trainloader, testloader, epochs, save_model_every_n_epochs=0):
        """
        Train the model for the specified number of epochs.

        After each epoch the model is evaluated on ``testloader`` and the metrics are printed.
        If ``save_model_every_n_epochs`` is positive, a checkpoint is written every that many
        epochs, except after the final epoch, where ``save_experiment`` is called instead.
        """
        # Keep track of the losses and accuracies
        train_losses, test_losses, accuracies = [], [], []
        # Train the model
        for i in tqdm(range(epochs), desc="Training"):
            train_loss = self.train_epoch(trainloader)
            accuracy, test_loss = self.evaluate(testloader)
            train_losses.append(train_loss)
            test_losses.append(test_loss)
            accuracies.append(accuracy)
            print(f"Epoch: {i+1}, Train loss: {train_loss:.4f}, Test loss: {test_loss:.4f}, Accuracy: {accuracy:.4f}")
            if save_model_every_n_epochs > 0 and (i+1) % save_model_every_n_epochs == 0 and i+1 != epochs:
                print('\tSave checkpoint at epoch', i+1)
                save_checkpoint(self.exp_name, self.model, i+1)
        # Save the experiment; prefer the config given to the constructor over the global one
        experiment_config = self.config if self.config is not None else config
        save_experiment(self.exp_name, experiment_config, self.model, train_losses, test_losses, accuracies)

    def train_epoch(self, trainloader):
        """
        Train the model for one epoch.

        Returns:
            The loss averaged over the samples actually iterated. Dividing by the number of
            samples seen (instead of ``len(trainloader.dataset)``) keeps the average correct
            when the loader drops the last batch or uses a sampler.

        Raises:
            ZeroDivisionError: if the loader yields no samples.
        """
        self.model.train()
        total_loss = 0.0
        num_samples = 0
        for batch in trainloader:
            # Move the batch to the device
            images, labels = [t.to(self.device) for t in batch]
            # Zero the gradients
            self.optimizer.zero_grad()
            # Calculate the loss
            loss = self.loss_fn(self.model(images)[0], labels)
            # Backpropagate the loss
            loss.backward()
            # Update the model's parameters
            self.optimizer.step()
            # Weight by batch size — assumes loss_fn returns a per-batch mean
            # (the default for nn.CrossEntropyLoss); verify for other losses.
            batch_len = len(images)
            total_loss += loss.item() * batch_len
            num_samples += batch_len
        return total_loss / num_samples

    @torch.no_grad()
    def evaluate(self, testloader):
        """
        Evaluate the model without tracking gradients.

        Returns:
            ``(accuracy, avg_loss)``, both averaged over the samples actually iterated.

        Raises:
            ZeroDivisionError: if the loader yields no samples.
        """
        self.model.eval()
        total_loss = 0.0
        correct = 0
        num_samples = 0
        # The decorator already disables autograd, so no inner no_grad context is needed.
        for batch in testloader:
            # Move the batch to the device
            images, labels = [t.to(self.device) for t in batch]

            # Get predictions
            logits, _ = self.model(images)

            # Calculate the loss
            loss = self.loss_fn(logits, labels)
            batch_len = len(images)
            total_loss += loss.item() * batch_len
            num_samples += batch_len

            # Calculate the accuracy
            predictions = torch.argmax(logits, dim=1)
            correct += torch.sum(predictions == labels).item()
        accuracy = correct / num_samples
        avg_loss = total_loss / num_samples
        return accuracy, avg_loss




In [8]:

def get_hyperparameters():
    """
    Build the default training hyperparameters.

    Returns:
        A new dict on every call with the keys ``exp_name``, ``batch_size``, ``epochs``,
        ``lr``, ``device`` and ``save_model_every``. ``device`` is ``"cuda"`` when
        ``torch.cuda.is_available()`` reports a GPU, otherwise ``"cpu"``.
    """
    # Resolve the device first so the returned literal stays purely declarative.
    if torch.cuda.is_available():
        selected_device = "cuda"
    else:
        selected_device = "cpu"

    return {
        "exp_name": "default_experiment",
        "batch_size": 256,
        "epochs": 151,
        "lr": 1e-2,
        "device": selected_device,
        "save_model_every": 50,
    }

# Example usage:
# hyperparameters = get_hyperparameters()
# print(hyperparameters)


ViTForClassfication(
  (embedding): Embeddings(
    (patch_embeddings): PatchEmbeddings(
      (projection): Conv2d(3, 48, kernel_size=(4, 4), stride=(4, 4))
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): Encoder(
    (blocks): ModuleList(
      (0-3): 4 x Block(
        (attention): FasterMultiHeadAttention(
          (qkv_projection): Linear(in_features=48, out_features=144, bias=True)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (output_projection): Linear(in_features=48, out_features=48, bias=True)
          (output_dropout): Dropout(p=0.0, inplace=False)
        )
        (layernorm_1): LayerNorm((48,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (dense_1): Linear(in_features=48, out_features=192, bias=True)
          (activation): NewGELUActivation()
          (dense_2): Linear(in_features=192, out_features=48, bias=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (layernorm_2): LayerNorm((4

<__main__.Trainer at 0x7da3a0309840>

In [9]:


hyperparameters = get_hyperparameters()

# Unpack the training parameters into the notebook-level names later cells rely on
exp_name, batch_size, epochs, lr, device, save_model_every_n_epochs = (
    hyperparameters[key]
    for key in ("exp_name", "batch_size", "epochs", "lr", "device", "save_model_every")
)

# Data loaders; the third value returned by prepare_data is not used here
trainloader, testloader, _ = prepare_data(batch_size=batch_size)

# Build the model, loss function, optimizer and trainer, then run the training loop
model = ViTForClassfication(config)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-2)
trainer = Trainer(model, optimizer, loss_fn, exp_name, device=device)
trainer.train(
    trainloader,
    testloader,
    epochs,
    save_model_every_n_epochs=save_model_every_n_epochs,
)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:12<00:00, 13489935.18it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


Training:   0%|          | 1/500 [00:34<4:45:55, 34.38s/it]

Epoch: 1, Train loss: 1.8506, Test loss: 1.6934, Accuracy: 0.3795


Training:   0%|          | 2/500 [01:06<4:34:08, 33.03s/it]

Epoch: 2, Train loss: 1.5675, Test loss: 1.4708, Accuracy: 0.4655


Training:   1%|          | 3/500 [01:42<4:45:24, 34.46s/it]

Epoch: 3, Train loss: 1.4379, Test loss: 1.4014, Accuracy: 0.4860


Training:   1%|          | 4/500 [02:15<4:38:18, 33.67s/it]

Epoch: 4, Train loss: 1.3766, Test loss: 1.3511, Accuracy: 0.5140


Training:   1%|          | 5/500 [02:47<4:35:12, 33.36s/it]

Epoch: 5, Train loss: 1.3297, Test loss: 1.2605, Accuracy: 0.5503


Training:   1%|          | 6/500 [03:20<4:32:30, 33.10s/it]

Epoch: 6, Train loss: 1.2647, Test loss: 1.2773, Accuracy: 0.5314


Training:   1%|▏         | 7/500 [03:54<4:33:09, 33.24s/it]

Epoch: 7, Train loss: 1.2409, Test loss: 1.2529, Accuracy: 0.5524


Training:   2%|▏         | 8/500 [04:27<4:32:27, 33.23s/it]

Epoch: 8, Train loss: 1.2025, Test loss: 1.2019, Accuracy: 0.5686


Training:   2%|▏         | 9/500 [04:59<4:29:26, 32.93s/it]

Epoch: 9, Train loss: 1.1744, Test loss: 1.2302, Accuracy: 0.5530


Training:   2%|▏         | 10/500 [05:32<4:29:34, 33.01s/it]

Epoch: 10, Train loss: 1.1448, Test loss: 1.1718, Accuracy: 0.5718


Training:   2%|▏         | 11/500 [06:05<4:29:28, 33.06s/it]

Epoch: 11, Train loss: 1.1254, Test loss: 1.2076, Accuracy: 0.5734


Training:   2%|▏         | 12/500 [06:39<4:29:19, 33.11s/it]

Epoch: 12, Train loss: 1.1027, Test loss: 1.0961, Accuracy: 0.6078


Training:   3%|▎         | 13/500 [07:11<4:26:48, 32.87s/it]

Epoch: 13, Train loss: 1.0735, Test loss: 1.1497, Accuracy: 0.5875


Training:   3%|▎         | 14/500 [07:46<4:30:35, 33.41s/it]

Epoch: 14, Train loss: 1.0626, Test loss: 1.0748, Accuracy: 0.6149


Training:   3%|▎         | 15/500 [08:18<4:28:24, 33.21s/it]

Epoch: 15, Train loss: 1.0367, Test loss: 1.1282, Accuracy: 0.5990


Training:   3%|▎         | 16/500 [08:51<4:27:19, 33.14s/it]

Epoch: 16, Train loss: 1.0114, Test loss: 1.0153, Accuracy: 0.6327


Training:   3%|▎         | 17/500 [09:24<4:26:44, 33.14s/it]

Epoch: 17, Train loss: 1.0039, Test loss: 1.0567, Accuracy: 0.6218


Training:   4%|▎         | 18/500 [09:59<4:29:07, 33.50s/it]

Epoch: 18, Train loss: 0.9939, Test loss: 1.0565, Accuracy: 0.6175


Training:   4%|▍         | 19/500 [10:31<4:26:06, 33.19s/it]

Epoch: 19, Train loss: 0.9651, Test loss: 0.9961, Accuracy: 0.6435


Training:   4%|▍         | 20/500 [11:04<4:23:25, 32.93s/it]

Epoch: 20, Train loss: 0.9475, Test loss: 0.9860, Accuracy: 0.6458


Training:   4%|▍         | 21/500 [11:37<4:24:42, 33.16s/it]

Epoch: 21, Train loss: 0.9395, Test loss: 0.9679, Accuracy: 0.6582


Training:   4%|▍         | 22/500 [12:11<4:24:40, 33.22s/it]

Epoch: 22, Train loss: 0.9262, Test loss: 0.9625, Accuracy: 0.6571


Training:   5%|▍         | 23/500 [12:44<4:24:24, 33.26s/it]

Epoch: 23, Train loss: 0.9122, Test loss: 1.0232, Accuracy: 0.6420


Training:   5%|▍         | 24/500 [13:16<4:21:06, 32.91s/it]

Epoch: 24, Train loss: 0.8975, Test loss: 0.9226, Accuracy: 0.6772


Training:   5%|▌         | 25/500 [13:49<4:21:40, 33.05s/it]

Epoch: 25, Train loss: 0.8864, Test loss: 0.9878, Accuracy: 0.6585


Training:   5%|▌         | 26/500 [14:23<4:22:02, 33.17s/it]

Epoch: 26, Train loss: 0.8886, Test loss: 0.9789, Accuracy: 0.6478


Training:   5%|▌         | 27/500 [14:57<4:22:52, 33.35s/it]

Epoch: 27, Train loss: 0.8687, Test loss: 0.9123, Accuracy: 0.6821


Training:   6%|▌         | 28/500 [15:29<4:20:02, 33.06s/it]

Epoch: 28, Train loss: 0.8596, Test loss: 0.9622, Accuracy: 0.6690


Training:   6%|▌         | 29/500 [16:03<4:20:47, 33.22s/it]

Epoch: 29, Train loss: 0.8414, Test loss: 0.8846, Accuracy: 0.6887


Training:   6%|▌         | 30/500 [16:36<4:21:04, 33.33s/it]

Epoch: 30, Train loss: 0.8360, Test loss: 0.9635, Accuracy: 0.6643


Training:   6%|▌         | 31/500 [17:10<4:21:49, 33.50s/it]

Epoch: 31, Train loss: 0.8256, Test loss: 0.9037, Accuracy: 0.6827


Training:   6%|▋         | 32/500 [17:43<4:19:05, 33.22s/it]

Epoch: 32, Train loss: 0.8178, Test loss: 0.9699, Accuracy: 0.6607


Training:   7%|▋         | 33/500 [18:17<4:21:50, 33.64s/it]

Epoch: 33, Train loss: 0.8135, Test loss: 0.9147, Accuracy: 0.6784


Training:   7%|▋         | 34/500 [18:50<4:18:23, 33.27s/it]

Epoch: 34, Train loss: 0.8083, Test loss: 0.8841, Accuracy: 0.6921


Training:   7%|▋         | 35/500 [19:23<4:17:06, 33.18s/it]

Epoch: 35, Train loss: 0.7972, Test loss: 0.8592, Accuracy: 0.6976


Training:   7%|▋         | 36/500 [19:56<4:17:13, 33.26s/it]

Epoch: 36, Train loss: 0.7892, Test loss: 0.8565, Accuracy: 0.7010


Training:   7%|▋         | 37/500 [20:31<4:21:13, 33.85s/it]

Epoch: 37, Train loss: 0.7837, Test loss: 0.8982, Accuracy: 0.6804


Training:   8%|▊         | 38/500 [21:04<4:17:37, 33.46s/it]

Epoch: 38, Train loss: 0.7731, Test loss: 0.8631, Accuracy: 0.6963


Training:   8%|▊         | 39/500 [21:37<4:16:30, 33.38s/it]

Epoch: 39, Train loss: 0.7748, Test loss: 0.8089, Accuracy: 0.7184


Training:   8%|▊         | 40/500 [22:10<4:15:22, 33.31s/it]

Epoch: 40, Train loss: 0.7567, Test loss: 0.8445, Accuracy: 0.7045


Training:   8%|▊         | 41/500 [22:44<4:16:43, 33.56s/it]

Epoch: 41, Train loss: 0.7598, Test loss: 0.8826, Accuracy: 0.6886


Training:   8%|▊         | 42/500 [23:18<4:16:17, 33.58s/it]

Epoch: 42, Train loss: 0.7661, Test loss: 0.8924, Accuracy: 0.6943


Training:   9%|▊         | 43/500 [23:50<4:12:33, 33.16s/it]

Epoch: 43, Train loss: 0.7519, Test loss: 0.8479, Accuracy: 0.7032


Training:   9%|▉         | 44/500 [24:24<4:13:44, 33.39s/it]

Epoch: 44, Train loss: 0.7501, Test loss: 0.8412, Accuracy: 0.7031


Training:   9%|▉         | 45/500 [24:59<4:15:46, 33.73s/it]

Epoch: 45, Train loss: 0.7282, Test loss: 0.8056, Accuracy: 0.7193


Training:   9%|▉         | 46/500 [25:32<4:14:57, 33.70s/it]

Epoch: 46, Train loss: 0.7260, Test loss: 0.8370, Accuracy: 0.7100


Training:   9%|▉         | 47/500 [26:05<4:11:43, 33.34s/it]

Epoch: 47, Train loss: 0.7185, Test loss: 0.8239, Accuracy: 0.7085


Training:  10%|▉         | 48/500 [26:38<4:11:15, 33.35s/it]

Epoch: 48, Train loss: 0.7243, Test loss: 0.7880, Accuracy: 0.7238


Training:  10%|▉         | 49/500 [27:12<4:11:03, 33.40s/it]

Epoch: 49, Train loss: 0.7251, Test loss: 0.8497, Accuracy: 0.7126


Training:  10%|█         | 50/500 [27:45<4:11:33, 33.54s/it]

Epoch: 50, Train loss: 0.7101, Test loss: 0.8678, Accuracy: 0.7056


Training:  10%|█         | 51/500 [28:18<4:08:35, 33.22s/it]

Epoch: 51, Train loss: 0.7078, Test loss: 0.8428, Accuracy: 0.7064


Training:  10%|█         | 52/500 [28:53<4:11:28, 33.68s/it]

Epoch: 52, Train loss: 0.7007, Test loss: 0.7838, Accuracy: 0.7279


Training:  11%|█         | 53/500 [29:25<4:08:36, 33.37s/it]

Epoch: 53, Train loss: 0.6899, Test loss: 0.8789, Accuracy: 0.7018


Training:  11%|█         | 54/500 [29:59<4:09:09, 33.52s/it]

Epoch: 54, Train loss: 0.6962, Test loss: 0.7950, Accuracy: 0.7224


Training:  11%|█         | 55/500 [30:32<4:06:29, 33.23s/it]

Epoch: 55, Train loss: 0.6939, Test loss: 0.7666, Accuracy: 0.7328


Training:  11%|█         | 56/500 [31:07<4:10:16, 33.82s/it]

Epoch: 56, Train loss: 0.6858, Test loss: 0.7891, Accuracy: 0.7302


Training:  11%|█▏        | 57/500 [31:40<4:07:15, 33.49s/it]

Epoch: 57, Train loss: 0.6793, Test loss: 0.7932, Accuracy: 0.7239


Training:  12%|█▏        | 58/500 [32:14<4:07:27, 33.59s/it]

Epoch: 58, Train loss: 0.6768, Test loss: 0.8063, Accuracy: 0.7197


Training:  12%|█▏        | 59/500 [32:47<4:05:43, 33.43s/it]

Epoch: 59, Train loss: 0.6775, Test loss: 0.8563, Accuracy: 0.7047


Training:  12%|█▏        | 60/500 [33:22<4:09:15, 33.99s/it]

Epoch: 60, Train loss: 0.6732, Test loss: 0.8079, Accuracy: 0.7222


Training:  12%|█▏        | 61/500 [33:55<4:06:05, 33.64s/it]

Epoch: 61, Train loss: 0.6813, Test loss: 0.7810, Accuracy: 0.7279


Training:  12%|█▏        | 62/500 [34:28<4:05:52, 33.68s/it]

Epoch: 62, Train loss: 0.6499, Test loss: 0.7607, Accuracy: 0.7343


Training:  13%|█▎        | 63/500 [35:01<4:02:00, 33.23s/it]

Epoch: 63, Train loss: 0.6570, Test loss: 0.7776, Accuracy: 0.7300


Training:  13%|█▎        | 64/500 [35:36<4:06:04, 33.86s/it]

Epoch: 64, Train loss: 0.6558, Test loss: 0.7881, Accuracy: 0.7280


Training:  13%|█▎        | 65/500 [36:08<4:02:01, 33.38s/it]

Epoch: 65, Train loss: 0.6528, Test loss: 0.7839, Accuracy: 0.7286


Training:  13%|█▎        | 66/500 [36:42<4:01:50, 33.44s/it]

Epoch: 66, Train loss: 0.6591, Test loss: 0.7470, Accuracy: 0.7392


Training:  13%|█▎        | 67/500 [37:14<3:59:10, 33.14s/it]

Epoch: 67, Train loss: 0.6354, Test loss: 0.7447, Accuracy: 0.7423


Training:  14%|█▎        | 68/500 [37:49<4:02:26, 33.67s/it]

Epoch: 68, Train loss: 0.6357, Test loss: 0.7311, Accuracy: 0.7440


Training:  14%|█▍        | 69/500 [38:22<4:00:11, 33.44s/it]

Epoch: 69, Train loss: 0.6353, Test loss: 0.7622, Accuracy: 0.7362


Training:  14%|█▍        | 70/500 [38:56<3:59:55, 33.48s/it]

Epoch: 70, Train loss: 0.6319, Test loss: 0.7328, Accuracy: 0.7474


Training:  14%|█▍        | 71/500 [39:30<4:00:16, 33.60s/it]

Epoch: 71, Train loss: 0.6319, Test loss: 0.8250, Accuracy: 0.7127


Training:  14%|█▍        | 72/500 [40:04<4:02:22, 33.98s/it]

Epoch: 72, Train loss: 0.6311, Test loss: 0.7244, Accuracy: 0.7488


Training:  15%|█▍        | 73/500 [40:37<3:59:32, 33.66s/it]

Epoch: 73, Train loss: 0.6317, Test loss: 0.7808, Accuracy: 0.7334


Training:  15%|█▍        | 74/500 [41:11<3:58:33, 33.60s/it]

Epoch: 74, Train loss: 0.6172, Test loss: 0.7657, Accuracy: 0.7390


Training:  15%|█▌        | 75/500 [41:44<3:58:02, 33.61s/it]

Epoch: 75, Train loss: 0.6215, Test loss: 0.7720, Accuracy: 0.7348


Training:  15%|█▌        | 76/500 [42:18<3:58:13, 33.71s/it]

Epoch: 76, Train loss: 0.6241, Test loss: 0.7817, Accuracy: 0.7303


Training:  15%|█▌        | 77/500 [42:51<3:55:33, 33.41s/it]

Epoch: 77, Train loss: 0.6216, Test loss: 0.7394, Accuracy: 0.7461


Training:  16%|█▌        | 78/500 [43:25<3:56:46, 33.66s/it]

Epoch: 78, Train loss: 0.6095, Test loss: 0.7769, Accuracy: 0.7359


Training:  16%|█▌        | 79/500 [44:00<3:57:31, 33.85s/it]

Epoch: 79, Train loss: 0.6069, Test loss: 0.7697, Accuracy: 0.7346


Training:  16%|█▌        | 80/500 [44:34<3:57:11, 33.88s/it]

Epoch: 80, Train loss: 0.6141, Test loss: 0.7616, Accuracy: 0.7386


Training:  16%|█▌        | 81/500 [45:06<3:54:00, 33.51s/it]

Epoch: 81, Train loss: 0.6160, Test loss: 0.7373, Accuracy: 0.7462


Training:  16%|█▋        | 82/500 [45:40<3:54:02, 33.60s/it]

Epoch: 82, Train loss: 0.6047, Test loss: 0.8794, Accuracy: 0.7106


Training:  17%|█▋        | 83/500 [46:14<3:54:19, 33.72s/it]

Epoch: 83, Train loss: 0.6046, Test loss: 0.7143, Accuracy: 0.7551


Training:  17%|█▋        | 84/500 [46:48<3:53:57, 33.74s/it]

Epoch: 84, Train loss: 0.6055, Test loss: 0.7411, Accuracy: 0.7450


Training:  17%|█▋        | 85/500 [47:21<3:51:32, 33.47s/it]

Epoch: 85, Train loss: 0.6093, Test loss: 0.7182, Accuracy: 0.7544


Training:  17%|█▋        | 86/500 [47:55<3:52:06, 33.64s/it]

Epoch: 86, Train loss: 0.6025, Test loss: 0.7604, Accuracy: 0.7425


Training:  17%|█▋        | 87/500 [48:29<3:52:22, 33.76s/it]

Epoch: 87, Train loss: 0.5944, Test loss: 0.7237, Accuracy: 0.7540


Training:  18%|█▊        | 88/500 [49:03<3:52:16, 33.83s/it]

Epoch: 88, Train loss: 0.5962, Test loss: 0.7220, Accuracy: 0.7465


Training:  18%|█▊        | 89/500 [49:35<3:49:04, 33.44s/it]

Epoch: 89, Train loss: 0.5901, Test loss: 0.7203, Accuracy: 0.7593


Training:  18%|█▊        | 90/500 [50:09<3:48:37, 33.46s/it]

Epoch: 90, Train loss: 0.5941, Test loss: 0.7563, Accuracy: 0.7450


Training:  18%|█▊        | 91/500 [50:41<3:46:33, 33.24s/it]

Epoch: 91, Train loss: 0.5888, Test loss: 0.7034, Accuracy: 0.7558


Training:  18%|█▊        | 92/500 [51:15<3:46:16, 33.28s/it]

Epoch: 92, Train loss: 0.5886, Test loss: 0.7872, Accuracy: 0.7324


Training:  19%|█▊        | 93/500 [51:47<3:43:03, 32.88s/it]

Epoch: 93, Train loss: 0.5810, Test loss: 0.7360, Accuracy: 0.7468


Training:  19%|█▉        | 94/500 [52:21<3:44:42, 33.21s/it]

Epoch: 94, Train loss: 0.5781, Test loss: 0.7256, Accuracy: 0.7581


Training:  19%|█▉        | 95/500 [52:54<3:44:13, 33.22s/it]

Epoch: 95, Train loss: 0.5832, Test loss: 0.7473, Accuracy: 0.7454


Training:  19%|█▉        | 96/500 [53:26<3:41:56, 32.96s/it]

Epoch: 96, Train loss: 0.5879, Test loss: 0.7424, Accuracy: 0.7414


Training:  19%|█▉        | 97/500 [54:00<3:42:36, 33.14s/it]

Epoch: 97, Train loss: 0.5756, Test loss: 0.7526, Accuracy: 0.7467


Training:  20%|█▉        | 98/500 [54:33<3:42:43, 33.24s/it]

Epoch: 98, Train loss: 0.5753, Test loss: 0.7330, Accuracy: 0.7543


Training:  20%|█▉        | 99/500 [55:07<3:41:54, 33.20s/it]

Epoch: 99, Train loss: 0.5740, Test loss: 0.7177, Accuracy: 0.7541


Training:  20%|██        | 100/500 [55:39<3:39:03, 32.86s/it]

Epoch: 100, Train loss: 0.5686, Test loss: 0.7525, Accuracy: 0.7464


Training:  20%|██        | 101/500 [56:12<3:40:23, 33.14s/it]

Epoch: 101, Train loss: 0.5670, Test loss: 0.8168, Accuracy: 0.7296


Training:  20%|██        | 102/500 [56:46<3:41:34, 33.40s/it]

Epoch: 102, Train loss: 0.5844, Test loss: 0.7469, Accuracy: 0.7370


Training:  21%|██        | 103/500 [57:20<3:41:37, 33.50s/it]

Epoch: 103, Train loss: 0.5724, Test loss: 0.7461, Accuracy: 0.7436


Training:  21%|██        | 104/500 [57:52<3:38:51, 33.16s/it]

Epoch: 104, Train loss: 0.5787, Test loss: 0.7585, Accuracy: 0.7471


Training:  21%|██        | 105/500 [58:26<3:39:57, 33.41s/it]

Epoch: 105, Train loss: 0.5739, Test loss: 0.7487, Accuracy: 0.7455


Training:  21%|██        | 106/500 [59:00<3:40:09, 33.53s/it]

Epoch: 106, Train loss: 0.5676, Test loss: 0.7429, Accuracy: 0.7506


Training:  21%|██▏       | 107/500 [59:34<3:40:12, 33.62s/it]

Epoch: 107, Train loss: 0.5693, Test loss: 0.7308, Accuracy: 0.7509


Training:  22%|██▏       | 108/500 [1:00:06<3:36:51, 33.19s/it]

Epoch: 108, Train loss: 0.5575, Test loss: 0.7328, Accuracy: 0.7485


Training:  22%|██▏       | 109/500 [1:00:39<3:35:27, 33.06s/it]

Epoch: 109, Train loss: 0.5590, Test loss: 0.7220, Accuracy: 0.7593


Training:  22%|██▏       | 110/500 [1:01:13<3:36:25, 33.30s/it]

Epoch: 110, Train loss: 0.5705, Test loss: 0.7608, Accuracy: 0.7396


Training:  22%|██▏       | 111/500 [1:01:46<3:34:41, 33.12s/it]

Epoch: 111, Train loss: 0.5654, Test loss: 0.7127, Accuracy: 0.7536


Training:  22%|██▏       | 112/500 [1:02:19<3:35:07, 33.27s/it]

Epoch: 112, Train loss: 0.5634, Test loss: 0.7316, Accuracy: 0.7533


Training:  23%|██▎       | 113/500 [1:02:52<3:33:06, 33.04s/it]

Epoch: 113, Train loss: 0.5625, Test loss: 0.7123, Accuracy: 0.7645


Training:  23%|██▎       | 114/500 [1:03:27<3:35:56, 33.57s/it]

Epoch: 114, Train loss: 0.5551, Test loss: 0.7271, Accuracy: 0.7523


Training:  23%|██▎       | 115/500 [1:03:59<3:33:13, 33.23s/it]

Epoch: 115, Train loss: 0.5580, Test loss: 0.7159, Accuracy: 0.7601


Training:  23%|██▎       | 116/500 [1:04:33<3:34:00, 33.44s/it]

Epoch: 116, Train loss: 0.5627, Test loss: 0.7341, Accuracy: 0.7566


Training:  23%|██▎       | 117/500 [1:05:05<3:31:36, 33.15s/it]

Epoch: 117, Train loss: 0.5572, Test loss: 0.7631, Accuracy: 0.7413


Training:  24%|██▎       | 118/500 [1:05:40<3:34:49, 33.74s/it]

Epoch: 118, Train loss: 0.5500, Test loss: 0.7703, Accuracy: 0.7437


Training:  24%|██▍       | 119/500 [1:06:13<3:32:33, 33.47s/it]

Epoch: 119, Train loss: 0.5570, Test loss: 0.7552, Accuracy: 0.7459


Training:  24%|██▍       | 120/500 [1:06:47<3:32:18, 33.52s/it]

Epoch: 120, Train loss: 0.5513, Test loss: 0.6916, Accuracy: 0.7565


Training:  24%|██▍       | 121/500 [1:07:19<3:29:25, 33.16s/it]

Epoch: 121, Train loss: 0.5623, Test loss: 0.7507, Accuracy: 0.7504


Training:  24%|██▍       | 122/500 [1:07:54<3:31:48, 33.62s/it]

Epoch: 122, Train loss: 0.5589, Test loss: 0.6805, Accuracy: 0.7718


Training:  25%|██▍       | 123/500 [1:08:26<3:28:16, 33.15s/it]

Epoch: 123, Train loss: 0.5501, Test loss: 0.7300, Accuracy: 0.7540


Training:  25%|██▍       | 124/500 [1:08:59<3:27:54, 33.18s/it]

Epoch: 124, Train loss: 0.5521, Test loss: 0.7235, Accuracy: 0.7504


Training:  25%|██▌       | 125/500 [1:09:32<3:27:09, 33.15s/it]

Epoch: 125, Train loss: 0.5412, Test loss: 0.7369, Accuracy: 0.7557


Training:  25%|██▌       | 126/500 [1:10:07<3:30:14, 33.73s/it]

Epoch: 126, Train loss: 0.5436, Test loss: 0.8200, Accuracy: 0.7268


Training:  25%|██▌       | 127/500 [1:10:40<3:27:17, 33.34s/it]

Epoch: 127, Train loss: 0.5440, Test loss: 0.7079, Accuracy: 0.7562


Training:  26%|██▌       | 128/500 [1:11:13<3:26:01, 33.23s/it]

Epoch: 128, Train loss: 0.5344, Test loss: 0.7187, Accuracy: 0.7592


Training:  26%|██▌       | 129/500 [1:11:47<3:27:37, 33.58s/it]

Epoch: 129, Train loss: 0.5496, Test loss: 0.7597, Accuracy: 0.7420


Training:  26%|██▌       | 130/500 [1:12:21<3:26:30, 33.49s/it]

Epoch: 130, Train loss: 0.5428, Test loss: 0.8099, Accuracy: 0.7332


Training:  26%|██▌       | 131/500 [1:12:53<3:24:19, 33.22s/it]

Epoch: 131, Train loss: 0.5392, Test loss: 0.7474, Accuracy: 0.7454


Training:  26%|██▋       | 132/500 [1:13:25<3:22:08, 32.96s/it]

Epoch: 132, Train loss: 0.5438, Test loss: 0.7237, Accuracy: 0.7555


Training:  27%|██▋       | 133/500 [1:14:00<3:24:46, 33.48s/it]

Epoch: 133, Train loss: 0.5493, Test loss: 0.7559, Accuracy: 0.7474


Training:  27%|██▋       | 134/500 [1:14:32<3:21:49, 33.09s/it]

Epoch: 134, Train loss: 0.5363, Test loss: 0.7512, Accuracy: 0.7524


Training:  27%|██▋       | 135/500 [1:15:05<3:21:25, 33.11s/it]

Epoch: 135, Train loss: 0.5333, Test loss: 0.7350, Accuracy: 0.7517


Training:  27%|██▋       | 136/500 [1:15:37<3:18:52, 32.78s/it]

Epoch: 136, Train loss: 0.5316, Test loss: 0.8367, Accuracy: 0.7301


Training:  27%|██▋       | 137/500 [1:16:12<3:22:03, 33.40s/it]

Epoch: 137, Train loss: 0.5348, Test loss: 0.7842, Accuracy: 0.7444


Training:  28%|██▊       | 138/500 [1:16:45<3:19:36, 33.09s/it]

Epoch: 138, Train loss: 0.5375, Test loss: 0.7059, Accuracy: 0.7628


Training:  28%|██▊       | 139/500 [1:17:18<3:19:52, 33.22s/it]

Epoch: 139, Train loss: 0.5384, Test loss: 0.7410, Accuracy: 0.7499


Training:  28%|██▊       | 140/500 [1:17:51<3:17:45, 32.96s/it]

Epoch: 140, Train loss: 0.5362, Test loss: 0.7504, Accuracy: 0.7479


Training:  28%|██▊       | 141/500 [1:18:26<3:21:03, 33.60s/it]

Epoch: 141, Train loss: 0.5375, Test loss: 0.6775, Accuracy: 0.7705


Training:  28%|██▊       | 142/500 [1:18:58<3:18:06, 33.20s/it]

Epoch: 142, Train loss: 0.5279, Test loss: 0.7156, Accuracy: 0.7582


Training:  29%|██▊       | 143/500 [1:19:31<3:16:58, 33.10s/it]

Epoch: 143, Train loss: 0.5269, Test loss: 0.6986, Accuracy: 0.7663


Training:  29%|██▉       | 144/500 [1:20:04<3:15:40, 32.98s/it]

Epoch: 144, Train loss: 0.5295, Test loss: 0.7394, Accuracy: 0.7534


Training:  29%|██▉       | 145/500 [1:20:37<3:16:38, 33.24s/it]

Epoch: 145, Train loss: 0.5218, Test loss: 0.7351, Accuracy: 0.7566


Training:  29%|██▉       | 146/500 [1:21:11<3:17:00, 33.39s/it]

Epoch: 146, Train loss: 0.5337, Test loss: 0.8026, Accuracy: 0.7339


Training:  29%|██▉       | 147/500 [1:21:44<3:14:42, 33.09s/it]

Epoch: 147, Train loss: 0.5262, Test loss: 0.6757, Accuracy: 0.7713


Training:  30%|██▉       | 148/500 [1:22:18<3:17:24, 33.65s/it]

Epoch: 148, Train loss: 0.5281, Test loss: 0.6994, Accuracy: 0.7636


Training:  30%|██▉       | 149/500 [1:22:53<3:18:22, 33.91s/it]

Epoch: 149, Train loss: 0.5300, Test loss: 0.6877, Accuracy: 0.7700


Training:  30%|███       | 150/500 [1:23:27<3:17:46, 33.90s/it]

Epoch: 150, Train loss: 0.5218, Test loss: 0.6888, Accuracy: 0.7687


Training:  30%|███       | 150/500 [1:23:30<3:14:51, 33.40s/it]


KeyboardInterrupt: 

In [18]:

import torch

def save_checkpoint(model, optimizer, epoch, path="checkpoint.pth", scheduler=None):
    """
    Write a resumable training checkpoint to ``path`` with ``torch.save``.

    The saved dict holds ``epoch``, ``model_state_dict`` and ``optimizer_state_dict``;
    ``scheduler_state_dict`` is added only when a scheduler is passed.

    NOTE(review): this redefines the name ``save_checkpoint`` that the notebook imports
    from ``utils`` and that ``Trainer.train`` calls as ``save_checkpoint(exp_name, model, epoch)``.
    Once this cell has run, that call resolves to this function with mismatched arguments.

    Args:
        model: object exposing ``state_dict()``.
        optimizer: object exposing ``state_dict()``.
        epoch: epoch number stored in the checkpoint and echoed in the log line.
        path: destination file.
        scheduler: optional object exposing ``state_dict()``.
    """
    state = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
    )
    # The scheduler entry is optional so checkpoints without one stay minimal.
    if scheduler is not None:
        state["scheduler_state_dict"] = scheduler.state_dict()

    torch.save(state, path)
    print(f"Checkpoint saved at epoch {epoch}.")

# Example usage:

# Checkpoint the objects currently held by the trainer.
model, optimizer = trainer.model, trainer.optimizer
scheduler = None  # nothing to store: the Trainer above takes no scheduler
# NOTE(review): hand-entered epoch — presumably the last completed epoch of the interrupted run; confirm.
epoch = 150
save_checkpoint(model, optimizer, epoch, path="checkpoint_1.pth", scheduler=scheduler)




Checkpoint saved at epoch 150.


In [16]:


def load_checkpoint(path="checkpoint.pth", model=None, optimizer=None, scheduler=None, map_location=None):
    """
    Restore training state from a checkpoint dict read with ``torch.load``.

    Each of ``model``, ``optimizer`` and ``scheduler`` is restored only when it is passed,
    so the function can also be used to load weights alone (for example for inference).

    Args:
        path: checkpoint file to read.
        model: optional object whose ``load_state_dict`` receives ``model_state_dict``.
        optimizer: optional object whose ``load_state_dict`` receives ``optimizer_state_dict``.
        scheduler: optional object whose ``load_state_dict`` receives ``scheduler_state_dict``.
        map_location: forwarded to ``torch.load``; pass e.g. ``"cpu"`` to open a checkpoint
            saved on a GPU on a machine without one.

    Returns:
        The epoch number stored in the checkpoint.

    Raises:
        KeyError: if a scheduler is passed but the checkpoint holds no scheduler state,
            or if another expected key is missing.
    """
    checkpoint = torch.load(path, map_location=map_location)

    # The parameters default to None, so guard each one instead of failing on None.
    if model is not None:
        model.load_state_dict(checkpoint["model_state_dict"])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    if scheduler is not None:
        if "scheduler_state_dict" not in checkpoint:
            raise KeyError(
                f"checkpoint {path!r} has no 'scheduler_state_dict'; "
                "it was saved without a scheduler"
            )
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

    epoch = checkpoint["epoch"]
    print(f"Checkpoint loaded. Resuming from epoch {epoch}.")

    return epoch

# Example usage:
# start_epoch = load_checkpoint(path="my_checkpoint.pth", model=model, optimizer=optimizer, scheduler=scheduler)



In [10]:
# NOTE(review): unconditional raise — presumably a deliberate stop so "Run All" halts before the cells below; confirm, and remove for a clean top-to-bottom run.
raise EOFError

NameError: name 'model' is not defined

In [11]:

# Display the current value of `epochs`; it is only defined after the training cell has run in this kernel session.
epochs

NameError: name 'epochs' is not defined

In [12]:

# Display the current `optimizer`; it is only defined after the training or checkpoint cell has run in this kernel session.
optimizer

NameError: name 'optimizer' is not defined

In [None]:
def load_checkpoint(path="checkpoint.pth", model=None, optimizer=None, scheduler=None, map_location=None):
    """
    Restore training state from a checkpoint dict read with ``torch.load``.

    NOTE(review): a function with this name is also defined in an earlier cell of this
    notebook; whichever cell runs last wins. Consider keeping a single definition.

    Each of ``model``, ``optimizer`` and ``scheduler`` is restored only when it is passed,
    so the function can also be used to load weights alone (for example for inference).

    Args:
        path: checkpoint file to read.
        model: optional object whose ``load_state_dict`` receives ``model_state_dict``.
        optimizer: optional object whose ``load_state_dict`` receives ``optimizer_state_dict``.
        scheduler: optional object whose ``load_state_dict`` receives ``scheduler_state_dict``.
        map_location: forwarded to ``torch.load``; pass e.g. ``"cpu"`` to open a checkpoint
            saved on a GPU on a machine without one.

    Returns:
        The epoch number stored in the checkpoint.

    Raises:
        KeyError: if a scheduler is passed but the checkpoint holds no scheduler state,
            or if another expected key is missing.
    """
    checkpoint = torch.load(path, map_location=map_location)

    # The parameters default to None, so guard each one instead of failing on None.
    if model is not None:
        model.load_state_dict(checkpoint["model_state_dict"])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    if scheduler is not None:
        if "scheduler_state_dict" not in checkpoint:
            raise KeyError(
                f"checkpoint {path!r} has no 'scheduler_state_dict'; "
                "it was saved without a scheduler"
            )
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

    epoch = checkpoint["epoch"]
    print(f"Checkpoint loaded. Resuming from epoch {epoch}.")

    return epoch

# Example usage:
# Load the file written by the save_checkpoint call earlier in this notebook ("checkpoint_1.pth");
# the previous path "my_checkpoint.pth" is never created anywhere in the notebook.
start_epoch = load_checkpoint(path="checkpoint_1.pth", model=model, optimizer=optimizer, scheduler=scheduler)


In [None]:

# NOTE(review): unconditional raise — presumably a leftover stop/scratch cell; confirm, and remove for a clean top-to-bottom run.
raise ZeroDivisionError