In [1]:
from dataloader import training_dataset, testing_dataset
import torch
from torch.utils.data import DataLoader
from vit_pytorch import ViT
from architecture import baseline
from tqdm import tqdm
import os

In [2]:
# Pick the compute device once; every later `.to(device)` call relies on it.
# Uses the GPU when PyTorch can see one, otherwise falls back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
BATCH_SIZE = 32

# Shuffle the training split every epoch so batches differ between epochs.
# The evaluation split is NOT shuffled: a fixed order makes metrics and any
# per-sample predictions reproducible and comparable across runs.
training_data = DataLoader(training_dataset, batch_size=BATCH_SIZE, shuffle=True)
testing_data = DataLoader(testing_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [4]:
# Vision Transformer classifier with 39 output classes.
# image_size=384 with patch_size=32 means the model is configured for a 12x12 patch grid.
# NOTE(review): 128 * 3 = 384. If the dataset yields 128x128 RGB images, the "* 3" looks
# like the channel count leaking into the spatial size (vit_pytorch takes the channel
# count through its separate `channels` argument). Confirm the input resolution produced
# by `dataloader` before relying on this value.
vision_model = ViT(
    image_size=128 * 3,
    patch_size=32,
    num_classes=39,
    dim=512,
    depth=6,
    heads=16,
    mlp_dim=1024,
    dropout=0.1,
    emb_dropout=0.1
).to(device)  # honour the device chosen above instead of hard-coding 'cuda' (fails on CPU-only hosts)


In [5]:
# Multi-class classification criterion applied to the ViT's class scores.
loss_fn = torch.nn.CrossEntropyLoss()

# NOTE(review): 1e-7 is unusually small for AdamW; typical ViT recipes use roughly
# 1e-4 to 3e-4 (confirm for this dataset). In the recorded 100-epoch run, the mean epoch
# loss only fell from ~16.8 to ~12.5. That is still far above ln(39) ~= 3.66, the loss of
# a uniform guess over 39 classes, which is consistent with a step size that is too small.
# The value is kept unchanged here so results stay reproducible; tune it deliberately.
LEARNING_RATE = 1e-7
optimizer = torch.optim.AdamW(vision_model.parameters(), lr=LEARNING_RATE)

In [6]:
def train_loop(dataloader, model, loss_fn=loss_fn, optimizer=optimizer, epochs=10, save_path='./saves'):
    """Train ``model`` for ``epochs`` passes over ``dataloader``, checkpointing after each epoch.

    Args:
        dataloader: Iterable of ``(data, targets)`` batches that supports ``len()``.
        model: Module to train. It is moved to the notebook-level ``device`` and put
            in training mode at the start of every epoch.
        loss_fn: Criterion called as ``loss_fn(scores, targets)``. The default is the
            ``loss_fn`` object that existed when this cell was executed.
        optimizer: Optimizer that updates ``model``'s parameters. The default is the
            ``optimizer`` object that existed when this cell was executed, which was
            built over ``vision_model.parameters()``.
        epochs: Number of full passes over ``dataloader``.
        save_path: Checkpoint directory; created, including missing parents, if absent.

    Raises:
        ValueError: if ``dataloader`` yields no batches, or if ``optimizer`` does not
            hold any of ``model``'s parameters. The second case happens, for example,
            when a different model is passed but the default optimizer is kept, so
            ``model`` would never actually be updated.

    After every epoch, ``<save_path>/model_epoch_<k>.pth`` is written. It contains the
    1-based epoch number, the model and optimizer state dicts, and the mean per-batch
    training loss of that epoch.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("dataloader is empty; nothing to train on")

    # Default arguments are bound when this cell runs, so the default optimizer only
    # knows the parameters of the model that existed at that time. Fail loudly rather
    # than silently stepping some other model's weights.
    model_param_ids = {id(p) for p in model.parameters()}
    if not any(id(p) in model_param_ids
               for group in optimizer.param_groups
               for p in group['params']):
        raise ValueError("optimizer does not manage any of model's parameters; "
                         "pass an optimizer built from model.parameters()")

    # Unlike os.mkdir, makedirs also creates missing parents and tolerates an existing dir.
    os.makedirs(save_path, exist_ok=True)

    model.to(device)

    for epoch in range(1, epochs + 1):
        print(f"Epoch {epoch}/{epochs}")
        model.train()
        loop = tqdm(dataloader, total=num_batches, leave=True)
        total_loss = 0.0

        for data, targets in loop:
            data = data.to(device)
            targets = targets.to(device)

            # Forward pass
            scores = model(data)
            loss = loss_fn(scores, targets)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Read the scalar once: each .item() forces a device-to-host sync.
            batch_loss = loss.item()
            total_loss += batch_loss
            loop.set_postfix(loss=batch_loss)

        # Mean of per-batch losses (a smaller final batch is weighted like any other).
        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch} average loss: {avg_loss}")

        # Save the model after each epoch
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
        }, os.path.join(save_path, f"model_epoch_{epoch}.pth"))

    print("Training complete!")

In [7]:
# Train the ViT on the training split for 100 epochs, relying on train_loop's default
# loss_fn, optimizer and save_path.
# TODO(review): `testing_data` is never evaluated in this notebook; add an eval pass.
train_loop(
    dataloader=training_data,
    model=vision_model,
    epochs=100,
)

Epoch 1/100


100%|██████████| 64/64 [01:10<00:00,  1.09s/it, loss=15.8]


Epoch 1 average loss: 16.84277005493641
Epoch 2/100


100%|██████████| 64/64 [01:09<00:00,  1.09s/it, loss=13.2]


Epoch 2 average loss: 16.66206881403923
Epoch 3/100


100%|██████████| 64/64 [01:09<00:00,  1.09s/it, loss=15.7]


Epoch 3 average loss: 16.475641503930092
Epoch 4/100


100%|██████████| 64/64 [01:09<00:00,  1.08s/it, loss=15.2]


Epoch 4 average loss: 16.30940955877304
Epoch 5/100


100%|██████████| 64/64 [01:09<00:00,  1.08s/it, loss=15.4]


Epoch 5 average loss: 16.147932916879654
Epoch 6/100


100%|██████████| 64/64 [01:08<00:00,  1.08s/it, loss=15.7]


Epoch 6 average loss: 15.999331176280975
Epoch 7/100


100%|██████████| 64/64 [01:09<00:00,  1.09s/it, loss=13.6]


Epoch 7 average loss: 15.831502974033356
Epoch 8/100


100%|██████████| 64/64 [01:08<00:00,  1.08s/it, loss=15.5]


Epoch 8 average loss: 15.657419681549072
Epoch 9/100


100%|██████████| 64/64 [01:10<00:00,  1.10s/it, loss=15.3]


Epoch 9 average loss: 15.502362996339798
Epoch 10/100


100%|██████████| 64/64 [01:10<00:00,  1.10s/it, loss=15.6]


Epoch 10 average loss: 15.372580096125603
Epoch 11/100


100%|██████████| 64/64 [01:08<00:00,  1.07s/it, loss=14.7]


Epoch 11 average loss: 15.21179211139679
Epoch 12/100


100%|██████████| 64/64 [01:08<00:00,  1.08s/it, loss=12.3]


Epoch 12 average loss: 15.09493725001812
Epoch 13/100


100%|██████████| 64/64 [01:08<00:00,  1.07s/it, loss=14.3]


Epoch 13 average loss: 14.952007681131363
Epoch 14/100


100%|██████████| 64/64 [01:07<00:00,  1.05s/it, loss=13.5]


Epoch 14 average loss: 14.838738113641739
Epoch 15/100


100%|██████████| 64/64 [01:07<00:00,  1.06s/it, loss=15.5]


Epoch 15 average loss: 14.713290840387344
Epoch 16/100


100%|██████████| 64/64 [01:07<00:00,  1.06s/it, loss=15]  


Epoch 16 average loss: 14.58939677476883
Epoch 17/100


100%|██████████| 64/64 [01:07<00:00,  1.06s/it, loss=13.6]


Epoch 17 average loss: 14.491391852498055
Epoch 18/100


100%|██████████| 64/64 [01:08<00:00,  1.07s/it, loss=15.4]


Epoch 18 average loss: 14.398758068680763
Epoch 19/100


100%|██████████| 64/64 [01:08<00:00,  1.07s/it, loss=13.8]


Epoch 19 average loss: 14.298154219985008
Epoch 20/100


100%|██████████| 64/64 [01:07<00:00,  1.06s/it, loss=16]  


Epoch 20 average loss: 14.223856896162033
Epoch 21/100


100%|██████████| 64/64 [01:07<00:00,  1.06s/it, loss=14.7]


Epoch 21 average loss: 14.118781551718712
Epoch 22/100


100%|██████████| 64/64 [01:08<00:00,  1.08s/it, loss=14.8]


Epoch 22 average loss: 14.05745829641819
Epoch 23/100


100%|██████████| 64/64 [01:09<00:00,  1.09s/it, loss=12.9]


Epoch 23 average loss: 13.975418955087662
Epoch 24/100


100%|██████████| 64/64 [01:10<00:00,  1.10s/it, loss=14.6]


Epoch 24 average loss: 13.890080377459526
Epoch 25/100


100%|██████████| 64/64 [01:09<00:00,  1.08s/it, loss=15.5]


Epoch 25 average loss: 13.829097598791122
Epoch 26/100


100%|██████████| 64/64 [01:09<00:00,  1.08s/it, loss=13.8]


Epoch 26 average loss: 13.762872368097305
Epoch 27/100


100%|██████████| 64/64 [01:09<00:00,  1.08s/it, loss=13.3]


Epoch 27 average loss: 13.695926994085312
Epoch 28/100


100%|██████████| 64/64 [01:09<00:00,  1.08s/it, loss=14.4]


Epoch 28 average loss: 13.644237712025642
Epoch 29/100


100%|██████████| 64/64 [01:09<00:00,  1.09s/it, loss=15.2]


Epoch 29 average loss: 13.605458468198776
Epoch 30/100


100%|██████████| 64/64 [01:08<00:00,  1.08s/it, loss=13.6]


Epoch 30 average loss: 13.546604499220848
Epoch 31/100


100%|██████████| 64/64 [01:08<00:00,  1.06s/it, loss=13.1]


Epoch 31 average loss: 13.492343112826347
Epoch 32/100


100%|██████████| 64/64 [01:09<00:00,  1.08s/it, loss=15.5]


Epoch 32 average loss: 13.459567695856094
Epoch 33/100


100%|██████████| 64/64 [01:07<00:00,  1.06s/it, loss=11.8]


Epoch 33 average loss: 13.415278434753418
Epoch 34/100


100%|██████████| 64/64 [01:07<00:00,  1.06s/it, loss=16]  


Epoch 34 average loss: 13.380952686071396
Epoch 35/100


100%|██████████| 64/64 [01:08<00:00,  1.07s/it, loss=13.7]


Epoch 35 average loss: 13.319440454244614
Epoch 36/100


100%|██████████| 64/64 [01:09<00:00,  1.08s/it, loss=11.4]


Epoch 36 average loss: 13.283379316329956
Epoch 37/100


100%|██████████| 64/64 [01:08<00:00,  1.06s/it, loss=14.8]


Epoch 37 average loss: 13.267118155956268
Epoch 38/100


100%|██████████| 64/64 [01:10<00:00,  1.10s/it, loss=13.1]


Epoch 38 average loss: 13.223471909761429
Epoch 39/100


100%|██████████| 64/64 [01:08<00:00,  1.08s/it, loss=11.3]


Epoch 39 average loss: 13.19005797803402
Epoch 40/100


100%|██████████| 64/64 [01:08<00:00,  1.08s/it, loss=12.7]


Epoch 40 average loss: 13.15562716126442
Epoch 41/100


100%|██████████| 64/64 [01:09<00:00,  1.08s/it, loss=11.9]


Epoch 41 average loss: 13.13057129085064
Epoch 42/100


100%|██████████| 64/64 [01:10<00:00,  1.11s/it, loss=15.4]


Epoch 42 average loss: 13.12477320432663
Epoch 43/100


100%|██████████| 64/64 [01:17<00:00,  1.21s/it, loss=11.8]


Epoch 43 average loss: 13.087032437324524
Epoch 44/100


100%|██████████| 64/64 [01:15<00:00,  1.19s/it, loss=12.1]


Epoch 44 average loss: 13.063168421387672
Epoch 45/100


100%|██████████| 64/64 [01:10<00:00,  1.11s/it, loss=15.6]


Epoch 45 average loss: 13.036425963044167
Epoch 46/100


100%|██████████| 64/64 [01:08<00:00,  1.07s/it, loss=11.5]


Epoch 46 average loss: 13.011937364935875
Epoch 47/100


100%|██████████| 64/64 [01:09<00:00,  1.08s/it, loss=12.3]


Epoch 47 average loss: 12.99502269923687
Epoch 48/100


100%|██████████| 64/64 [01:09<00:00,  1.08s/it, loss=13.3]


Epoch 48 average loss: 12.97428160905838
Epoch 49/100


100%|██████████| 64/64 [01:15<00:00,  1.18s/it, loss=15.3]


Epoch 49 average loss: 12.95095644891262
Epoch 50/100


100%|██████████| 64/64 [01:09<00:00,  1.09s/it, loss=12.3]


Epoch 50 average loss: 12.932830154895782
Epoch 51/100


100%|██████████| 64/64 [01:08<00:00,  1.08s/it, loss=11.8]


Epoch 51 average loss: 12.925012215971947
Epoch 52/100


100%|██████████| 64/64 [01:09<00:00,  1.09s/it, loss=15.4]


Epoch 52 average loss: 12.903960168361664
Epoch 53/100


100%|██████████| 64/64 [01:10<00:00,  1.10s/it, loss=13.5]


Epoch 53 average loss: 12.884847030043602
Epoch 54/100


100%|██████████| 64/64 [01:10<00:00,  1.10s/it, loss=12.4]


Epoch 54 average loss: 12.85574959218502
Epoch 55/100


100%|██████████| 64/64 [01:11<00:00,  1.11s/it, loss=13.9]


Epoch 55 average loss: 12.84326285123825
Epoch 56/100


100%|██████████| 64/64 [01:09<00:00,  1.09s/it, loss=14.9]


Epoch 56 average loss: 12.84473367035389
Epoch 57/100


100%|██████████| 64/64 [01:11<00:00,  1.12s/it, loss=13.4]


Epoch 57 average loss: 12.839870125055313
Epoch 58/100


100%|██████████| 64/64 [01:11<00:00,  1.11s/it, loss=12.3]


Epoch 58 average loss: 12.784158036112785
Epoch 59/100


100%|██████████| 64/64 [01:11<00:00,  1.12s/it, loss=14]  


Epoch 59 average loss: 12.793582171201706
Epoch 60/100


100%|██████████| 64/64 [01:11<00:00,  1.11s/it, loss=11.4]


Epoch 60 average loss: 12.773007452487946
Epoch 61/100


100%|██████████| 64/64 [01:10<00:00,  1.10s/it, loss=13.1]


Epoch 61 average loss: 12.766188204288483
Epoch 62/100


100%|██████████| 64/64 [01:10<00:00,  1.11s/it, loss=13.2]


Epoch 62 average loss: 12.762589290738106
Epoch 63/100


100%|██████████| 64/64 [01:08<00:00,  1.08s/it, loss=13.7]


Epoch 63 average loss: 12.74703335762024
Epoch 64/100


100%|██████████| 64/64 [01:09<00:00,  1.08s/it, loss=12.6]


Epoch 64 average loss: 12.731297954916954
Epoch 65/100


100%|██████████| 64/64 [01:08<00:00,  1.08s/it, loss=11.7]


Epoch 65 average loss: 12.725455835461617
Epoch 66/100


100%|██████████| 64/64 [01:14<00:00,  1.16s/it, loss=13.8]


Epoch 66 average loss: 12.714159399271011
Epoch 67/100


100%|██████████| 64/64 [01:28<00:00,  1.39s/it, loss=14]  


Epoch 67 average loss: 12.704297721385956
Epoch 68/100


100%|██████████| 64/64 [01:28<00:00,  1.38s/it, loss=12.8]


Epoch 68 average loss: 12.706323772668839
Epoch 69/100


100%|██████████| 64/64 [01:27<00:00,  1.37s/it, loss=13.4]


Epoch 69 average loss: 12.68847505748272
Epoch 70/100


100%|██████████| 64/64 [01:23<00:00,  1.30s/it, loss=13.7]


Epoch 70 average loss: 12.67367747426033
Epoch 71/100


100%|██████████| 64/64 [01:24<00:00,  1.33s/it, loss=12]  


Epoch 71 average loss: 12.66847874224186
Epoch 72/100


100%|██████████| 64/64 [01:24<00:00,  1.32s/it, loss=14.1]


Epoch 72 average loss: 12.651372894644737
Epoch 73/100


100%|██████████| 64/64 [01:25<00:00,  1.34s/it, loss=12.9]


Epoch 73 average loss: 12.647148326039314
Epoch 74/100


100%|██████████| 64/64 [01:25<00:00,  1.34s/it, loss=13.7]


Epoch 74 average loss: 12.649397820234299
Epoch 75/100


100%|██████████| 64/64 [01:24<00:00,  1.32s/it, loss=11.4]


Epoch 75 average loss: 12.644207075238228
Epoch 76/100


100%|██████████| 64/64 [01:23<00:00,  1.31s/it, loss=14.2]


Epoch 76 average loss: 12.641956806182861
Epoch 77/100


100%|██████████| 64/64 [01:24<00:00,  1.32s/it, loss=12.7]


Epoch 77 average loss: 12.617999285459518
Epoch 78/100


100%|██████████| 64/64 [01:27<00:00,  1.36s/it, loss=13]  


Epoch 78 average loss: 12.639856964349747
Epoch 79/100


100%|██████████| 64/64 [01:23<00:00,  1.30s/it, loss=12.4]


Epoch 79 average loss: 12.606842532753944
Epoch 80/100


100%|██████████| 64/64 [01:26<00:00,  1.35s/it, loss=11.3]


Epoch 80 average loss: 12.601649135351181
Epoch 81/100


100%|██████████| 64/64 [01:23<00:00,  1.31s/it, loss=11]  


Epoch 81 average loss: 12.596442878246307
Epoch 82/100


100%|██████████| 64/64 [01:23<00:00,  1.31s/it, loss=13.5]


Epoch 82 average loss: 12.598374679684639
Epoch 83/100


100%|██████████| 64/64 [01:23<00:00,  1.30s/it, loss=13.9]


Epoch 83 average loss: 12.589273110032082
Epoch 84/100


100%|██████████| 64/64 [01:22<00:00,  1.29s/it, loss=12.3]


Epoch 84 average loss: 12.587448805570602
Epoch 85/100


100%|██████████| 64/64 [01:23<00:00,  1.31s/it, loss=14.8]


Epoch 85 average loss: 12.573984071612358
Epoch 86/100


100%|██████████| 64/64 [01:23<00:00,  1.31s/it, loss=12.4]


Epoch 86 average loss: 12.561858713626862
Epoch 87/100


100%|██████████| 64/64 [01:24<00:00,  1.32s/it, loss=12.9]


Epoch 87 average loss: 12.546852812170982
Epoch 88/100


100%|██████████| 64/64 [01:24<00:00,  1.33s/it, loss=15.1]


Epoch 88 average loss: 12.567009910941124
Epoch 89/100


100%|██████████| 64/64 [01:24<00:00,  1.32s/it, loss=14.3]


Epoch 89 average loss: 12.555438324809074
Epoch 90/100


100%|██████████| 64/64 [01:25<00:00,  1.33s/it, loss=13.4]


Epoch 90 average loss: 12.56875628232956
Epoch 91/100


100%|██████████| 64/64 [01:32<00:00,  1.45s/it, loss=15]  


Epoch 91 average loss: 12.558229252696037
Epoch 92/100


100%|██████████| 64/64 [01:24<00:00,  1.32s/it, loss=12]  


Epoch 92 average loss: 12.53929091989994
Epoch 93/100


100%|██████████| 64/64 [01:25<00:00,  1.33s/it, loss=11.6]


Epoch 93 average loss: 12.539132177829742
Epoch 94/100


100%|██████████| 64/64 [01:28<00:00,  1.39s/it, loss=14]  


Epoch 94 average loss: 12.540482684969902
Epoch 95/100


100%|██████████| 64/64 [01:25<00:00,  1.34s/it, loss=14.5]


Epoch 95 average loss: 12.539007723331451
Epoch 96/100


100%|██████████| 64/64 [01:25<00:00,  1.34s/it, loss=10.9]


Epoch 96 average loss: 12.51643317937851
Epoch 97/100


100%|██████████| 64/64 [01:29<00:00,  1.39s/it, loss=12.9]


Epoch 97 average loss: 12.510909870266914
Epoch 98/100


100%|██████████| 64/64 [01:23<00:00,  1.30s/it, loss=10]  


Epoch 98 average loss: 12.507078230381012
Epoch 99/100


100%|██████████| 64/64 [01:23<00:00,  1.30s/it, loss=11.4]


Epoch 99 average loss: 12.506341353058815
Epoch 100/100


100%|██████████| 64/64 [01:24<00:00,  1.32s/it, loss=13.1]


Epoch 100 average loss: 12.520131275057793
Training complete!
