In [12]:
from dataloader import training_dataset, testing_dataset
import torch
from torch.utils.data import DataLoader
from vit_pytorch import ViT
from architecture import baseline
from tqdm import tqdm
import os

In [17]:
# Select the compute device once: use the GPU when CUDA is usable,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [8]:
# Single source of truth for the mini-batch size used by both loaders.
BATCH_SIZE = 32

# Training batches are reshuffled every epoch.
training_data = DataLoader(training_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation does not benefit from shuffling; a fixed order keeps
# per-batch results reproducible between runs.
testing_data = DataLoader(testing_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [13]:
# Vision Transformer classifier with 39 output classes.
# `image_size` evaluates to 384 and must be divisible by `patch_size`.
vision_model = ViT(
    image_size=128 * 3,
    patch_size=32,
    num_classes=39,
    dim=512,
    depth=6,
    heads=16,
    mlp_dim=1024,
    dropout=0.1,
    emb_dropout=0.1
).to(device)  # honour the CPU fallback instead of hard-coding 'cuda'


In [14]:
# Cross-entropy loss applied to the scores returned by the model.
loss_fn = torch.nn.CrossEntropyLoss()
# The optimizer must hold the parameters of the model that is actually
# trained; `vision_model` is the model instantiated in this notebook.
optimizer = torch.optim.AdamW(vision_model.parameters())

In [15]:
def train_loop(dataloader, model, loss_fn=loss_fn, optimizer=optimizer, epochs=10, save_path='./saves'):
    """Train ``model`` on ``dataloader`` and write a checkpoint after every epoch.

    Args:
        dataloader: Iterable yielding ``(data, targets)`` batches; must support
            ``len()``.
        model: Module to train. It is moved to the module-level ``device``.
        loss_fn: Callable invoked as ``loss_fn(scores, targets)``. Defaults to
            the module-level ``loss_fn`` captured when this function was defined.
        optimizer: Optimizer that is zeroed and stepped once per batch. Defaults
            to the module-level ``optimizer`` captured when this function was
            defined. NOTE: it should have been constructed from ``model``'s
            parameters, otherwise stepping it will not update ``model``.
        epochs: Number of full passes over ``dataloader``.
        save_path: Directory that receives ``model_epoch_<n>.pth`` files. It is
            created (including missing parents) when absent.

    Raises:
        ValueError: If ``dataloader`` contains no batches.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        # Fail early with a clear message instead of a ZeroDivisionError
        # when the average loss is computed after a pass with no batches.
        raise ValueError("dataloader is empty; nothing to train on")

    # exist_ok avoids the check-then-create race and makedirs also handles
    # nested checkpoint directories.
    os.makedirs(save_path, exist_ok=True)

    model.to(device)

    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")
        model.train()
        loop = tqdm(dataloader, total=num_batches, leave=True)
        total_loss = 0

        for batch_idx, (data, targets) in enumerate(loop):
            data = data.to(device)
            targets = targets.to(device)

            # Forward pass
            scores = model(data)
            loss = loss_fn(scores, targets)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Read the scalar once; every .item() call copies the value
            # from the device to the host.
            batch_loss = loss.item()
            total_loss += batch_loss

            # Update progress bar
            loop.set_postfix(loss=batch_loss)

        average_loss = total_loss / num_batches
        print(f"Epoch {epoch + 1} average loss: {average_loss}")

        # Save the model after each epoch
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': average_loss,
        }, os.path.join(save_path, f"model_epoch_{epoch + 1}.pth"))

    print("Training complete!")

In [None]:
# Run training on the training loader for ten epochs, using the default
# loss, optimizer and checkpoint directory of train_loop.
train_loop(dataloader=training_data, model=vision_model, epochs=10)

Epoch 1/10


 36%|███▌      | 23/64 [00:25<00:44,  1.09s/it, loss=17.4]