In [1]:
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.models import vit_b_16, ViT_B_16_Weights
from torchvision.transforms import v2
from train import train_model

In [2]:
mean = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
std = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

train_transform = transforms.Compose([
        v2.RandomResizedCrop(size=(224, 224), antialias=True),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
])

test_transform = transforms.Compose([
        v2.RandomResizedCrop(size=(224, 224), antialias=True),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
])

train_data = torchvision.datasets.CIFAR100("./data", train=True, 
                                     transform=train_transform,download=True)

test_data = torchvision.datasets.CIFAR100("./data", train=False, 
                                     transform=test_transform,download=True)



Files already downloaded and verified
Files already downloaded and verified


In [3]:
train_dataloader = DataLoader(train_data, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=32, shuffle=True)

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

epoches = 200
batchsize = 16
lr = 0.001

weights = ViT_B_16_Weights.DEFAULT
model = vit_b_16(weights= weights)
model.heads[0] = nn.Linear(model.heads[0].in_features, 100)
model = model.to(device)

loss_fun = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr= lr, momentum=0.9)

In [5]:
sum([p.numel() for p in model.parameters()])

85875556

In [6]:
train_model(model, train_dataloader, test_dataloader, optimizer, loss_fun, num_epoches=100, cutmix_flag=False)

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1| 100 training complete!
------------------------------
train loss: 1.990595 acc: 0.526


  1%|          | 1/100 [09:17<15:20:12, 557.70s/it]

test loss: 1.360680 acc: 0.637
Epoch 2| 100 training complete!
------------------------------
train loss: 1.199564 acc: 0.675


  2%|▏         | 2/100 [18:34<15:10:19, 557.34s/it]

test loss: 1.165969 acc: 0.682
Epoch 3| 100 training complete!
------------------------------
train loss: 1.053946 acc: 0.710


  3%|▎         | 3/100 [27:55<15:03:49, 559.06s/it]

test loss: 1.055834 acc: 0.707
Epoch 4| 100 training complete!
------------------------------
train loss: 0.958424 acc: 0.734


  4%|▍         | 4/100 [37:18<14:56:55, 560.58s/it]

test loss: 1.028227 acc: 0.714
Epoch 5| 100 training complete!
------------------------------
train loss: 0.897402 acc: 0.749


  5%|▌         | 5/100 [46:41<14:48:58, 561.46s/it]

test loss: 0.997466 acc: 0.722
Epoch 6| 100 training complete!
------------------------------
train loss: 0.832090 acc: 0.768


  6%|▌         | 6/100 [56:03<14:39:41, 561.51s/it]

test loss: 0.966523 acc: 0.733
Epoch 7| 100 training complete!
------------------------------
train loss: 0.786491 acc: 0.781


  7%|▋         | 7/100 [1:05:24<14:30:04, 561.34s/it]

test loss: 0.948437 acc: 0.738
Epoch 8| 100 training complete!
------------------------------
train loss: 0.754465 acc: 0.791


  8%|▊         | 8/100 [1:14:45<14:20:45, 561.37s/it]

test loss: 0.925172 acc: 0.744
Epoch 9| 100 training complete!
------------------------------
train loss: 0.718919 acc: 0.799


  9%|▉         | 9/100 [1:24:06<14:11:09, 561.20s/it]

test loss: 0.991022 acc: 0.729
Epoch 10| 100 training complete!
------------------------------
train loss: 0.697199 acc: 0.807


 10%|█         | 10/100 [1:33:27<14:01:39, 561.10s/it]

test loss: 0.918918 acc: 0.750
Epoch 11| 100 training complete!
------------------------------
train loss: 0.659141 acc: 0.817


 11%|█         | 11/100 [1:42:48<13:52:19, 561.12s/it]

test loss: 0.922969 acc: 0.748
Epoch 12| 100 training complete!
------------------------------
train loss: 0.638516 acc: 0.824


 12%|█▏        | 12/100 [1:52:09<13:42:48, 561.00s/it]

test loss: 0.921615 acc: 0.752
Epoch 13| 100 training complete!
------------------------------
train loss: 0.621526 acc: 0.828


 13%|█▎        | 13/100 [2:01:30<13:33:33, 561.07s/it]

test loss: 0.917754 acc: 0.751
Epoch 14| 100 training complete!
------------------------------
train loss: 0.606543 acc: 0.832


 14%|█▍        | 14/100 [2:10:51<13:24:04, 560.98s/it]

test loss: 0.911139 acc: 0.750
Epoch 15| 100 training complete!
------------------------------
train loss: 0.580921 acc: 0.839


 15%|█▌        | 15/100 [2:20:12<13:14:55, 561.12s/it]

test loss: 0.904059 acc: 0.757
Epoch 16| 100 training complete!
------------------------------
train loss: 0.564498 acc: 0.844


 16%|█▌        | 16/100 [2:29:33<13:05:27, 561.04s/it]

test loss: 0.916828 acc: 0.751
Epoch 17| 100 training complete!
------------------------------
train loss: 0.551547 acc: 0.848


 17%|█▋        | 17/100 [2:38:53<12:55:39, 560.72s/it]

test loss: 0.900637 acc: 0.756
Epoch 18| 100 training complete!
------------------------------
train loss: 0.543283 acc: 0.851


 18%|█▊        | 18/100 [2:48:14<12:46:16, 560.69s/it]

test loss: 0.906930 acc: 0.751
Epoch 19| 100 training complete!
------------------------------
train loss: 0.530186 acc: 0.854


 19%|█▉        | 19/100 [2:57:35<12:37:08, 560.85s/it]

test loss: 0.894966 acc: 0.761
Epoch 20| 100 training complete!
------------------------------
train loss: 0.508396 acc: 0.861


 20%|██        | 20/100 [3:06:56<12:27:50, 560.88s/it]

test loss: 0.923215 acc: 0.756
Epoch 21| 100 training complete!
------------------------------
train loss: 0.502516 acc: 0.863


 21%|██        | 21/100 [3:16:16<12:18:18, 560.74s/it]

test loss: 0.918056 acc: 0.755
Epoch 22| 100 training complete!
------------------------------
train loss: 0.497830 acc: 0.863


 22%|██▏       | 22/100 [3:25:37<12:08:54, 560.69s/it]

test loss: 0.916595 acc: 0.759
Epoch 23| 100 training complete!
------------------------------
train loss: 0.481023 acc: 0.867


 23%|██▎       | 23/100 [3:34:57<11:59:24, 560.58s/it]

test loss: 0.914072 acc: 0.757
Epoch 24| 100 training complete!
------------------------------
train loss: 0.472680 acc: 0.870


 24%|██▍       | 24/100 [3:44:18<11:50:01, 560.55s/it]

test loss: 0.925794 acc: 0.761
Epoch 25| 100 training complete!
------------------------------
train loss: 0.461310 acc: 0.872


 25%|██▌       | 25/100 [3:53:39<11:40:59, 560.80s/it]

test loss: 0.908009 acc: 0.766
Epoch 26| 100 training complete!
------------------------------
train loss: 0.450130 acc: 0.876


 26%|██▌       | 26/100 [4:03:00<11:31:39, 560.80s/it]

test loss: 0.927579 acc: 0.757
Epoch 27| 100 training complete!
------------------------------
train loss: 0.444705 acc: 0.879


 27%|██▋       | 27/100 [4:12:21<11:22:16, 560.78s/it]

test loss: 0.904594 acc: 0.764
Epoch 28| 100 training complete!
------------------------------
train loss: 0.435083 acc: 0.881


 28%|██▊       | 28/100 [4:21:43<11:13:22, 561.15s/it]

test loss: 0.933338 acc: 0.753
Epoch 29| 100 training complete!
------------------------------
train loss: 0.432613 acc: 0.882


 29%|██▉       | 29/100 [4:31:05<11:04:19, 561.40s/it]

test loss: 0.903728 acc: 0.764
Epoch 30| 100 training complete!
------------------------------
train loss: 0.438295 acc: 0.879


 30%|███       | 30/100 [4:40:27<10:55:13, 561.62s/it]

test loss: 0.920493 acc: 0.761
Epoch 31| 100 training complete!
------------------------------
train loss: 0.415238 acc: 0.887


 31%|███       | 31/100 [4:49:50<10:46:24, 562.10s/it]

test loss: 0.898358 acc: 0.769
Epoch 32| 100 training complete!
------------------------------
train loss: 0.415660 acc: 0.887


 32%|███▏      | 32/100 [4:59:12<10:36:55, 562.00s/it]

test loss: 0.915151 acc: 0.760
Epoch 33| 100 training complete!
------------------------------
train loss: 0.409233 acc: 0.889


 33%|███▎      | 33/100 [5:08:34<10:27:38, 562.07s/it]

test loss: 0.928504 acc: 0.757
Epoch 34| 100 training complete!
------------------------------
train loss: 0.409125 acc: 0.888


 34%|███▍      | 34/100 [5:17:57<10:18:37, 562.39s/it]

test loss: 0.930163 acc: 0.761
Epoch 35| 100 training complete!
------------------------------
train loss: 0.407857 acc: 0.889


 35%|███▌      | 35/100 [5:27:21<10:09:34, 562.68s/it]

test loss: 0.906520 acc: 0.762
Epoch 36| 100 training complete!
------------------------------
train loss: 0.394583 acc: 0.893


 36%|███▌      | 36/100 [5:36:42<9:59:57, 562.46s/it] 

test loss: 0.923648 acc: 0.762
Epoch 37| 100 training complete!
------------------------------
train loss: 0.392612 acc: 0.894


 37%|███▋      | 37/100 [5:46:05<9:50:33, 562.44s/it]

test loss: 0.946392 acc: 0.760
Epoch 38| 100 training complete!
------------------------------
train loss: 0.383498 acc: 0.895


 38%|███▊      | 38/100 [5:55:28<9:41:19, 562.57s/it]

test loss: 0.893221 acc: 0.772
Epoch 39| 100 training complete!
------------------------------
train loss: 0.384457 acc: 0.896


 39%|███▉      | 39/100 [6:04:50<9:31:45, 562.39s/it]

test loss: 0.904492 acc: 0.767
Epoch 40| 100 training complete!
------------------------------
train loss: 0.381897 acc: 0.896


 40%|████      | 40/100 [6:14:12<9:22:26, 562.43s/it]

test loss: 0.899504 acc: 0.772
Epoch 41| 100 training complete!
------------------------------
train loss: 0.376307 acc: 0.897


 41%|████      | 41/100 [6:23:33<9:12:33, 561.93s/it]

test loss: 0.976550 acc: 0.754
Epoch 42| 100 training complete!
------------------------------
train loss: 0.372218 acc: 0.899


 42%|████▏     | 42/100 [6:32:54<9:02:55, 561.64s/it]

test loss: 0.950432 acc: 0.765
Epoch 43| 100 training complete!
------------------------------
train loss: 0.371088 acc: 0.899


 43%|████▎     | 43/100 [6:42:15<8:53:24, 561.48s/it]

test loss: 0.948040 acc: 0.763
Epoch 44| 100 training complete!
------------------------------
train loss: 0.364466 acc: 0.901


 44%|████▍     | 44/100 [6:51:36<8:43:47, 561.20s/it]

test loss: 0.909026 acc: 0.769
Epoch 45| 100 training complete!
------------------------------
train loss: 0.356373 acc: 0.904


 45%|████▌     | 45/100 [7:00:57<8:34:34, 561.35s/it]

test loss: 0.937693 acc: 0.766
Epoch 46| 100 training complete!
------------------------------
train loss: 0.356996 acc: 0.903


 46%|████▌     | 46/100 [7:10:18<8:25:08, 561.27s/it]

test loss: 0.932226 acc: 0.763
Epoch 47| 100 training complete!
------------------------------
train loss: 0.348818 acc: 0.907


 47%|████▋     | 47/100 [7:19:40<8:15:46, 561.25s/it]

test loss: 0.945510 acc: 0.764
Epoch 48| 100 training complete!
------------------------------
train loss: 0.351847 acc: 0.904


 48%|████▊     | 48/100 [7:29:01<8:06:23, 561.22s/it]

test loss: 0.951144 acc: 0.760
Epoch 49| 100 training complete!
------------------------------
train loss: 0.342145 acc: 0.907


 49%|████▉     | 49/100 [7:38:22<7:56:56, 561.11s/it]

test loss: 0.959133 acc: 0.761
Epoch 50| 100 training complete!
------------------------------
train loss: 0.338435 acc: 0.909


 50%|█████     | 50/100 [7:47:43<7:47:43, 561.27s/it]

test loss: 0.905639 acc: 0.768
Epoch 51| 100 training complete!
------------------------------
train loss: 0.337767 acc: 0.909


 51%|█████     | 51/100 [7:57:05<7:38:23, 561.30s/it]

test loss: 0.942133 acc: 0.769
Epoch 52| 100 training complete!
------------------------------
train loss: 0.332597 acc: 0.910


 52%|█████▏    | 52/100 [8:06:26<7:29:07, 561.41s/it]

test loss: 0.927825 acc: 0.770
Epoch 53| 100 training complete!
------------------------------
train loss: 0.337226 acc: 0.908


 53%|█████▎    | 53/100 [8:15:48<7:19:45, 561.39s/it]

test loss: 0.931826 acc: 0.771
Epoch 54| 100 training complete!
------------------------------
train loss: 0.332947 acc: 0.910


 54%|█████▍    | 54/100 [8:25:09<7:10:19, 561.30s/it]

test loss: 0.942944 acc: 0.765
Epoch 55| 100 training complete!
------------------------------
train loss: 0.326208 acc: 0.910


 55%|█████▌    | 55/100 [8:34:31<7:01:10, 561.57s/it]

test loss: 0.932533 acc: 0.773
Epoch 56| 100 training complete!
------------------------------
train loss: 0.326141 acc: 0.911


 56%|█████▌    | 56/100 [8:43:52<6:51:39, 561.36s/it]

test loss: 0.933141 acc: 0.766
Epoch 57| 100 training complete!
------------------------------
train loss: 0.324571 acc: 0.913


 57%|█████▋    | 57/100 [8:53:13<6:42:19, 561.38s/it]

test loss: 0.978003 acc: 0.762
Epoch 58| 100 training complete!
------------------------------
train loss: 0.323262 acc: 0.913


 58%|█████▊    | 58/100 [9:02:35<6:33:01, 561.46s/it]

test loss: 0.952150 acc: 0.765
Epoch 59| 100 training complete!
------------------------------
train loss: 0.318227 acc: 0.914


 59%|█████▉    | 59/100 [9:12:06<6:25:37, 564.32s/it]

test loss: 0.963486 acc: 0.762
Epoch 60| 100 training complete!
------------------------------
train loss: 0.324426 acc: 0.912


 60%|██████    | 60/100 [9:21:47<6:19:33, 569.35s/it]

test loss: 1.000959 acc: 0.755
Epoch 61| 100 training complete!
------------------------------
train loss: 0.312326 acc: 0.916


 61%|██████    | 61/100 [9:31:26<6:11:57, 572.25s/it]

test loss: 0.942529 acc: 0.771
Epoch 62| 100 training complete!
------------------------------
train loss: 0.308610 acc: 0.916


 62%|██████▏   | 62/100 [9:41:05<6:03:39, 574.21s/it]

test loss: 0.940912 acc: 0.768
Epoch 63| 100 training complete!
------------------------------
train loss: 0.303408 acc: 0.917


 63%|██████▎   | 63/100 [9:50:44<5:55:05, 575.82s/it]

test loss: 0.948292 acc: 0.771
Epoch 64| 100 training complete!
------------------------------
train loss: 0.298788 acc: 0.919


 64%|██████▍   | 64/100 [10:00:23<5:46:03, 576.77s/it]

test loss: 0.954564 acc: 0.771
Epoch 65| 100 training complete!
------------------------------
train loss: 0.302777 acc: 0.919


 65%|██████▌   | 65/100 [10:10:05<5:37:18, 578.23s/it]

test loss: 0.952614 acc: 0.770
Epoch 66| 100 training complete!
------------------------------
train loss: 0.306213 acc: 0.917


 66%|██████▌   | 66/100 [10:19:45<5:27:56, 578.72s/it]

test loss: 0.938948 acc: 0.769
Epoch 67| 100 training complete!
------------------------------
train loss: 0.294650 acc: 0.919


 67%|██████▋   | 67/100 [10:29:25<5:18:32, 579.17s/it]

test loss: 0.963057 acc: 0.767
Epoch 68| 100 training complete!
------------------------------
train loss: 0.300875 acc: 0.918


 68%|██████▊   | 68/100 [10:39:05<5:08:59, 579.36s/it]

test loss: 0.963649 acc: 0.769
Epoch 69| 100 training complete!
------------------------------
train loss: 0.296858 acc: 0.920


 69%|██████▉   | 69/100 [10:48:46<4:59:33, 579.80s/it]

test loss: 0.959121 acc: 0.768
Epoch 70| 100 training complete!
------------------------------
train loss: 0.294852 acc: 0.919


 70%|███████   | 70/100 [10:58:27<4:50:11, 580.40s/it]

test loss: 0.972123 acc: 0.762
Epoch 71| 100 training complete!
------------------------------
train loss: 0.295727 acc: 0.919


 71%|███████   | 71/100 [11:08:08<4:40:34, 580.50s/it]

test loss: 0.958815 acc: 0.770
Epoch 72| 100 training complete!
------------------------------
train loss: 0.287632 acc: 0.923


 72%|███████▏  | 72/100 [11:17:47<4:30:39, 579.98s/it]

test loss: 0.941089 acc: 0.773
Epoch 73| 100 training complete!
------------------------------
train loss: 0.282194 acc: 0.923


 73%|███████▎  | 73/100 [11:27:19<4:19:57, 577.67s/it]

test loss: 0.968618 acc: 0.771
Epoch 74| 100 training complete!
------------------------------
train loss: 0.289040 acc: 0.921


 74%|███████▍  | 74/100 [11:36:54<4:09:52, 576.64s/it]

test loss: 1.007331 acc: 0.767
Epoch 75| 100 training complete!
------------------------------
train loss: 0.287100 acc: 0.922


 75%|███████▌  | 75/100 [11:46:28<4:00:02, 576.10s/it]

test loss: 0.916561 acc: 0.775
Epoch 76| 100 training complete!
------------------------------
train loss: 0.289116 acc: 0.922


 76%|███████▌  | 76/100 [11:56:01<3:49:58, 574.95s/it]

test loss: 0.973928 acc: 0.766
Epoch 77| 100 training complete!
------------------------------
train loss: 0.280580 acc: 0.925


 77%|███████▋  | 77/100 [12:05:33<3:40:07, 574.24s/it]

test loss: 0.944016 acc: 0.773
Epoch 78| 100 training complete!
------------------------------
train loss: 0.279179 acc: 0.923


 78%|███████▊  | 78/100 [12:15:05<3:30:16, 573.48s/it]

test loss: 0.979454 acc: 0.769
Epoch 79| 100 training complete!
------------------------------
train loss: 0.276972 acc: 0.925


 79%|███████▉  | 79/100 [12:24:37<3:20:34, 573.08s/it]

test loss: 0.985558 acc: 0.769
Epoch 80| 100 training complete!
------------------------------
train loss: 0.281609 acc: 0.924


 80%|████████  | 80/100 [12:34:09<3:10:56, 572.83s/it]

test loss: 0.979881 acc: 0.766
Epoch 81| 100 training complete!
------------------------------
train loss: 0.278842 acc: 0.924


 81%|████████  | 81/100 [12:43:41<3:01:17, 572.49s/it]

test loss: 1.022984 acc: 0.764
Epoch 82| 100 training complete!
------------------------------
train loss: 0.282608 acc: 0.924


 82%|████████▏ | 82/100 [12:53:13<2:51:43, 572.44s/it]

test loss: 0.964815 acc: 0.769
Epoch 83| 100 training complete!
------------------------------
train loss: 0.269466 acc: 0.927


 83%|████████▎ | 83/100 [13:02:45<2:42:06, 572.16s/it]

test loss: 0.993578 acc: 0.765
Epoch 84| 100 training complete!
------------------------------
train loss: 0.262633 acc: 0.928


 84%|████████▍ | 84/100 [13:12:11<2:32:04, 570.31s/it]

test loss: 0.974889 acc: 0.770
Epoch 85| 100 training complete!
------------------------------
train loss: 0.269744 acc: 0.927


 85%|████████▌ | 85/100 [13:21:22<2:21:09, 564.62s/it]

test loss: 0.959351 acc: 0.770
Epoch 86| 100 training complete!
------------------------------
train loss: 0.263064 acc: 0.929


 86%|████████▌ | 86/100 [13:30:25<2:10:14, 558.16s/it]

test loss: 0.963279 acc: 0.771
Epoch 87| 100 training complete!
------------------------------
train loss: 0.270369 acc: 0.928


 87%|████████▋ | 87/100 [13:39:27<1:59:50, 553.12s/it]

test loss: 0.967481 acc: 0.769
Epoch 88| 100 training complete!
------------------------------
train loss: 0.265927 acc: 0.927


 88%|████████▊ | 88/100 [13:48:40<1:50:37, 553.11s/it]

test loss: 0.977560 acc: 0.766
Epoch 89| 100 training complete!
------------------------------
train loss: 0.262579 acc: 0.928


 89%|████████▉ | 89/100 [13:58:12<1:42:27, 558.83s/it]

test loss: 0.988074 acc: 0.766
Epoch 90| 100 training complete!
------------------------------
train loss: 0.264495 acc: 0.928


 90%|█████████ | 90/100 [14:07:46<1:33:53, 563.31s/it]

test loss: 0.992253 acc: 0.770
Epoch 91| 100 training complete!
------------------------------
train loss: 0.257874 acc: 0.930


 91%|█████████ | 91/100 [14:17:19<1:24:56, 566.33s/it]

test loss: 0.966878 acc: 0.769
Epoch 92| 100 training complete!
------------------------------
train loss: 0.261355 acc: 0.930


 92%|█████████▏| 92/100 [14:26:53<1:15:48, 568.53s/it]

test loss: 0.990118 acc: 0.770
Epoch 93| 100 training complete!
------------------------------
train loss: 0.258457 acc: 0.930


 93%|█████████▎| 93/100 [14:36:26<1:06:29, 569.87s/it]

test loss: 0.973068 acc: 0.771
Epoch 94| 100 training complete!
------------------------------
train loss: 0.262951 acc: 0.929


 94%|█████████▍| 94/100 [14:45:58<57:03, 570.64s/it]  

test loss: 0.984199 acc: 0.766
Epoch 95| 100 training complete!
------------------------------
train loss: 0.255848 acc: 0.930


 95%|█████████▌| 95/100 [14:55:28<47:31, 570.27s/it]

test loss: 1.028302 acc: 0.768
Epoch 96| 100 training complete!
------------------------------
train loss: 0.257502 acc: 0.931


 96%|█████████▌| 96/100 [15:04:46<37:47, 566.79s/it]

test loss: 0.997249 acc: 0.775
Epoch 97| 100 training complete!
------------------------------
train loss: 0.253984 acc: 0.931


 97%|█████████▋| 97/100 [15:13:52<28:01, 560.42s/it]

test loss: 0.982034 acc: 0.765
Epoch 98| 100 training complete!
------------------------------
train loss: 0.248425 acc: 0.933


 98%|█████████▊| 98/100 [15:23:00<18:33, 556.86s/it]

test loss: 0.997151 acc: 0.767
Epoch 99| 100 training complete!
------------------------------
train loss: 0.249370 acc: 0.933


 99%|█████████▉| 99/100 [15:32:06<09:13, 553.53s/it]

test loss: 0.986786 acc: 0.773
Epoch 100| 100 training complete!
------------------------------
train loss: 0.249283 acc: 0.932


100%|██████████| 100/100 [15:41:18<00:00, 564.78s/it]

test loss: 1.007304 acc: 0.763



