In [1]:
from dataloader import training_dataset, testing_dataset
import torch
from torch.utils.data import DataLoader
from vit_pytorch import ViT, SimpleViT
from torch import nn
from tqdm import tqdm
import os

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
training_data = DataLoader(training_dataset, batch_size=32, shuffle=True)
testing_data = DataLoader(testing_dataset, batch_size=32, shuffle=True)

In [4]:
vision_model = ViT(
    image_size=128 * 3,
    patch_size=32,
    num_classes=39,
    dim=512,
    depth=6,
    heads=16,
    mlp_dim=1024,
).to('cuda')

# vit_model = nn.Sequential(vision_model, nn.Sigmoid())

In [5]:
loss_fn = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(vision_model.parameters(), lr=1e-7)

In [6]:
def train_loop(dataloader, testing_data, model, loss_fn, optimizer, state={'loss': [], 'epochs': [], 'test_loss': []}, epochs=10, save_path='./saves'):
    if not os.path.exists(save_path):
        os.mkdir(save_path)

    all_exps = [int(elm.replace('exp_', '')) if elm != '.ipynb_checkpoints' else -1 for elm in os.listdir(save_path)]

    current_num = max(all_exps) if (len(all_exps) > 0) else 1
    save_path = save_path + '/exp_' + str(current_num + 1)
    os.mkdir(save_path)

    model.to(device)

    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")
        model.train()
        loop = tqdm(dataloader, total=len(dataloader), leave=True)
        total_loss = 0

        for batch_idx, (data, targets) in enumerate(loop):
            data = data.to(device)
            targets = targets.to(device)

            # Forward pass
            scores = model(data)
            loss = loss_fn(scores, targets)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            # Update progress bar
            loop.set_postfix(loss=loss.item())

        testing_loss = 0
        for test in tqdm(testing_data):
            X, y = test
            with torch.no_grad():
                model.eval()

                score = model(X)
                testing_loss =+ loss_fn(score, y)

                model.train()

        print(f"Epoch {epoch + 1} average loss: {total_loss / len(dataloader)} with testing loss of {testing_loss / len(testing_data)}")

        state['loss'].append(total_loss / len(dataloader))
        state['epochs'].append(epoch + 1)
        state['test_loss'].append(testing_loss / len(testing_data))

        # Save the model after each epoch
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': total_loss / len(dataloader),
            'testing_loss': testing_loss / len(testing_data)
        }, f"{save_path}/model_epoch_{epoch + 1}.pth")

    print("Training complete!")

In [7]:
state = train_loop(training_data, testing_data, vision_model, loss_fn, optimizer, epochs=100)
torch.save(state, './stats.pth')

Epoch 1/100


100%|██████████| 64/64 [01:10<00:00,  1.10s/it, loss=0.783]
100%|██████████| 13/13 [00:12<00:00,  1.00it/s]


Epoch 1 average loss: 0.788705300539732 with testing loss of 0.06024235114455223
Epoch 2/100


100%|██████████| 64/64 [01:09<00:00,  1.08s/it, loss=0.775]
100%|██████████| 13/13 [00:12<00:00,  1.02it/s]


Epoch 2 average loss: 0.7749776365235448 with testing loss of 0.05804092437028885
Epoch 3/100


100%|██████████| 64/64 [01:08<00:00,  1.07s/it, loss=0.757]
100%|██████████| 13/13 [00:12<00:00,  1.03it/s]


Epoch 3 average loss: 0.7615182269364595 with testing loss of 0.05732303112745285
Epoch 4/100


100%|██████████| 64/64 [01:09<00:00,  1.09s/it, loss=0.741]
100%|██████████| 13/13 [00:12<00:00,  1.02it/s]


Epoch 4 average loss: 0.7483181962743402 with testing loss of 0.057295992970466614
Epoch 5/100


100%|██████████| 64/64 [01:09<00:00,  1.09s/it, loss=0.726]
100%|██████████| 13/13 [00:12<00:00,  1.01it/s]


Epoch 5 average loss: 0.7353732539340854 with testing loss of 0.0549936480820179
Epoch 6/100


100%|██████████| 64/64 [01:09<00:00,  1.08s/it, loss=0.722]
100%|██████████| 13/13 [00:12<00:00,  1.01it/s]


Epoch 6 average loss: 0.7226880323141813 with testing loss of 0.05570589378476143
Epoch 7/100


100%|██████████| 64/64 [01:08<00:00,  1.07s/it, loss=0.71] 
100%|██████████| 13/13 [00:12<00:00,  1.02it/s]


Epoch 7 average loss: 0.7102065831422806 with testing loss of 0.053925033658742905
Epoch 8/100


100%|██████████| 64/64 [01:08<00:00,  1.07s/it, loss=0.688]
100%|██████████| 13/13 [00:12<00:00,  1.05it/s]


Epoch 8 average loss: 0.6978872725740075 with testing loss of 0.053008612245321274
Epoch 9/100


100%|██████████| 64/64 [01:08<00:00,  1.07s/it, loss=0.681]
100%|██████████| 13/13 [00:12<00:00,  1.03it/s]


Epoch 9 average loss: 0.6858319351449609 with testing loss of 0.05168682709336281
Epoch 10/100


100%|██████████| 64/64 [01:09<00:00,  1.08s/it, loss=0.665]
100%|██████████| 13/13 [00:12<00:00,  1.01it/s]


Epoch 10 average loss: 0.6739389635622501 with testing loss of 0.05018429458141327
Epoch 11/100


100%|██████████| 64/64 [01:08<00:00,  1.08s/it, loss=0.667]
100%|██████████| 13/13 [00:12<00:00,  1.03it/s]


Epoch 11 average loss: 0.6623102109879255 with testing loss of 0.05070415139198303
Epoch 12/100


100%|██████████| 64/64 [01:09<00:00,  1.08s/it, loss=0.652]
100%|██████████| 13/13 [00:12<00:00,  1.02it/s]


Epoch 12 average loss: 0.6508122086524963 with testing loss of 0.048902567476034164
Epoch 13/100


100%|██████████| 64/64 [01:08<00:00,  1.08s/it, loss=0.636]
100%|██████████| 13/13 [00:12<00:00,  1.05it/s]


Epoch 13 average loss: 0.6395142050459981 with testing loss of 0.04738008230924606
Epoch 14/100


100%|██████████| 64/64 [01:09<00:00,  1.08s/it, loss=0.628]
100%|██████████| 13/13 [00:12<00:00,  1.02it/s]


Epoch 14 average loss: 0.62842708081007 with testing loss of 0.04792407900094986
Epoch 15/100


100%|██████████| 64/64 [01:08<00:00,  1.07s/it, loss=0.601]
100%|██████████| 13/13 [00:12<00:00,  1.05it/s]


Epoch 15 average loss: 0.6174851423129439 with testing loss of 0.04746522009372711
Epoch 16/100


100%|██████████| 64/64 [01:08<00:00,  1.07s/it, loss=0.602]
100%|██████████| 13/13 [00:12<00:00,  1.03it/s]


Epoch 16 average loss: 0.6068114852532744 with testing loss of 0.04885761812329292
Epoch 17/100


100%|██████████| 64/64 [01:08<00:00,  1.07s/it, loss=0.596]
100%|██████████| 13/13 [00:12<00:00,  1.03it/s]


Epoch 17 average loss: 0.596326969563961 with testing loss of 0.04563378542661667
Epoch 18/100


100%|██████████| 64/64 [01:08<00:00,  1.07s/it, loss=0.592]
100%|██████████| 13/13 [00:12<00:00,  1.00it/s]


Epoch 18 average loss: 0.5860449392348528 with testing loss of 0.04532644525170326
Epoch 19/100


100%|██████████| 64/64 [01:07<00:00,  1.06s/it, loss=0.58] 
100%|██████████| 13/13 [00:13<00:00,  1.02s/it]


Epoch 19 average loss: 0.5759360352531075 with testing loss of 0.04546135291457176
Epoch 20/100


100%|██████████| 64/64 [01:09<00:00,  1.09s/it, loss=0.566]
100%|██████████| 13/13 [00:12<00:00,  1.05it/s]


Epoch 20 average loss: 0.5660385387018323 with testing loss of 0.043051827698946
Epoch 21/100


100%|██████████| 64/64 [01:09<00:00,  1.09s/it, loss=0.557]
100%|██████████| 13/13 [00:12<00:00,  1.06it/s]


Epoch 21 average loss: 0.5563562856987119 with testing loss of 0.046165142208337784
Epoch 22/100


100%|██████████| 64/64 [01:08<00:00,  1.08s/it, loss=0.538]
100%|██████████| 13/13 [00:12<00:00,  1.00it/s]


Epoch 22 average loss: 0.5468513565137982 with testing loss of 0.03983147442340851
Epoch 23/100


100%|██████████| 64/64 [01:09<00:00,  1.08s/it, loss=0.54] 
100%|██████████| 13/13 [00:12<00:00,  1.05it/s]


Epoch 23 average loss: 0.5376337738707662 with testing loss of 0.040869999676942825
Epoch 24/100


100%|██████████| 64/64 [01:08<00:00,  1.07s/it, loss=0.522]
100%|██████████| 13/13 [00:12<00:00,  1.02it/s]


Epoch 24 average loss: 0.5285677146166563 with testing loss of 0.0410231351852417
Epoch 25/100


100%|██████████| 64/64 [01:08<00:00,  1.07s/it, loss=0.532]
100%|██████████| 13/13 [00:12<00:00,  1.03it/s]


Epoch 25 average loss: 0.5198148596100509 with testing loss of 0.03838128596544266
Epoch 26/100


100%|██████████| 64/64 [01:09<00:00,  1.08s/it, loss=0.505]
100%|██████████| 13/13 [00:12<00:00,  1.01it/s]


Epoch 26 average loss: 0.5111454604193568 with testing loss of 0.04114057123661041
Epoch 27/100


100%|██████████| 64/64 [01:08<00:00,  1.07s/it, loss=0.505]
100%|██████████| 13/13 [00:12<00:00,  1.04it/s]


Epoch 27 average loss: 0.5028000492602587 with testing loss of 0.0402546226978302
Epoch 28/100


100%|██████████| 64/64 [01:08<00:00,  1.07s/it, loss=0.481]
100%|██████████| 13/13 [00:12<00:00,  1.01it/s]


Epoch 28 average loss: 0.4945832509547472 with testing loss of 0.04065763205289841
Epoch 29/100


  2%|▏         | 1/64 [00:01<01:28,  1.40s/it, loss=0.484]


KeyboardInterrupt: 