In [9]:
from dataloader import training_dataset, testing_dataset
import torch
from torch.utils.data import DataLoader
from vit_pytorch import ViT, SimpleViT
from torch import nn
from tqdm import tqdm
import os

In [10]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [11]:
training_data = DataLoader(training_dataset, batch_size=32, shuffle=True)
testing_data = DataLoader(testing_dataset, batch_size=32, shuffle=True)

In [12]:
vision_model = ViT(
    image_size=128 * 3,
    patch_size=32,
    num_classes=29,
    dim=512,
    depth=6,
    heads=16,
    mlp_dim=1024,
    dropout = 0.1,
    emb_dropout = 0.1
).to('cuda')

# vit_model = nn.Sequential(vision_model, nn.Sigmoid())

In [13]:
loss_fn = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(vision_model.parameters(), lr=1e-7)

In [14]:
cache = torch.load('saves/exp_2/model_epoch_32.pth')

In [15]:
optimizer.load_state_dict(cache['optimizer_state_dict'])
vision_model.load_state_dict(cache['model_state_dict'])

<All keys matched successfully>

In [16]:
state = {'loss': [], 'epochs': [], 'test_loss': []}
def train_loop(dataloader, testing_data, model, loss_fn, optimizer, epochs=10, save_path='./saves'):
    global state
    if not os.path.exists(save_path):
        os.mkdir(save_path)

    all_exps = [int(elm.replace('exp_', '')) if elm != '.ipynb_checkpoints' else -1 for elm in os.listdir(save_path)]

    current_num = max(all_exps) if (len(all_exps) > 0) else 1
    save_path = save_path + '/exp_' + str(current_num + 1)
    os.mkdir(save_path)

    model.to(device)

    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")
        model.train()
        loop = tqdm(dataloader, total=len(dataloader), leave=True)
        total_loss = 0

        for batch_idx, (data, targets) in enumerate(loop):
            data = data.to(device)
            targets = targets.to(device)

            # Forward pass
            scores = model(data)
            loss = loss_fn(scores, targets)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            # Update progress bar
            loop.set_postfix(loss=loss.item())

        testing_loss = 0
        for test in tqdm(testing_data):
            X, y = test
            with torch.no_grad():
                model.eval()

                score = model(X)
                testing_loss =+ loss_fn(score, y)

                model.train()

        print(f"Epoch {epoch + 1} average loss: {total_loss / len(dataloader)} with testing loss of {testing_loss / len(testing_data)}")

        state['loss'].append(total_loss / len(dataloader))
        state['epochs'].append(epoch + 1)
        state['test_loss'].append((testing_loss / len(testing_data)).cpu())

        # Save the model after each epoch
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': total_loss / len(dataloader),
            'testing_loss': (testing_loss / len(testing_data)).cpu()
        }, f"{save_path}/model_epoch_{epoch + 1}.pth")

    print("Training complete!")

In [17]:
train_loop(training_data, testing_data, vision_model, loss_fn, optimizer, epochs=10)

Epoch 1/10


100%|██████████| 64/64 [01:20<00:00,  1.26s/it, loss=0.416]
100%|██████████| 13/13 [00:13<00:00,  1.07s/it]


Epoch 1 average loss: 0.4333819351159036 with testing loss of 0.0302905086427927
Epoch 2/10


100%|██████████| 64/64 [01:10<00:00,  1.10s/it, loss=0.413]
100%|██████████| 13/13 [00:12<00:00,  1.03it/s]


Epoch 2 average loss: 0.4277427545748651 with testing loss of 0.03043803945183754
Epoch 3/10


100%|██████████| 64/64 [01:10<00:00,  1.11s/it, loss=0.424]
100%|██████████| 13/13 [00:13<00:00,  1.00s/it]


Epoch 3 average loss: 0.42264205776154995 with testing loss of 0.03406708687543869
Epoch 4/10


100%|██████████| 64/64 [01:12<00:00,  1.14s/it, loss=0.416]
100%|██████████| 13/13 [00:13<00:00,  1.04s/it]


Epoch 4 average loss: 0.4169648280367255 with testing loss of 0.03461531549692154
Epoch 5/10


100%|██████████| 64/64 [01:11<00:00,  1.11s/it, loss=0.401]
100%|██████████| 13/13 [00:13<00:00,  1.00s/it]


Epoch 5 average loss: 0.41227435041218996 with testing loss of 0.03187219053506851
Epoch 6/10


100%|██████████| 64/64 [01:10<00:00,  1.10s/it, loss=0.41] 
100%|██████████| 13/13 [00:13<00:00,  1.01s/it]


Epoch 6 average loss: 0.4074586592614651 with testing loss of 0.035555724054574966
Epoch 7/10


100%|██████████| 64/64 [01:11<00:00,  1.11s/it, loss=0.393]
100%|██████████| 13/13 [00:12<00:00,  1.03it/s]


Epoch 7 average loss: 0.40280989138409495 with testing loss of 0.02893841080367565
Epoch 8/10


100%|██████████| 64/64 [01:11<00:00,  1.12s/it, loss=0.377]
100%|██████████| 13/13 [00:12<00:00,  1.03it/s]


Epoch 8 average loss: 0.3976817522197962 with testing loss of 0.02897035889327526
Epoch 9/10


100%|██████████| 64/64 [01:11<00:00,  1.12s/it, loss=0.384]
100%|██████████| 13/13 [00:13<00:00,  1.01s/it]


Epoch 9 average loss: 0.3936554677784443 with testing loss of 0.02405449002981186
Epoch 10/10


100%|██████████| 64/64 [01:11<00:00,  1.11s/it, loss=0.369]
100%|██████████| 13/13 [00:13<00:00,  1.04s/it]


Epoch 10 average loss: 0.3888186230324209 with testing loss of 0.027012156322598457
Training complete!


In [9]:
state.keys()

dict_keys(['loss', 'epochs', 'test_loss'])