In [1]:
import torch
import os
import pickle
from tqdm import tqdm
from torch.utils.data import DataLoader, ConcatDataset

from utils.model_train import train
from utils.processing import load_data, process_data, augment_data, to_tensors, split_batch, incremental_save
from utils.quickdraw_cnn import QuickDrawCNN_V1, QuickDrawCNN_V2

In [2]:
# Directory whose entries are used as the category names.
image_dir = "data/numpy_bitmap/"
# os.listdir() returns entries in arbitrary order; sort them so the category
# order is identical on every run and on every machine.
categories = sorted(os.listdir(image_dir))
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Map each category name to an integer class index. The mapping used to be
# built by enumerating a set of strings, whose iteration order is not stable
# between interpreter runs, so the same category could receive a different
# index each time. Directory entries are already unique, so enumerating the
# sorted list gives a reproducible mapping over the same keys.
labels_map = {label: i for i, label in enumerate(categories)}

In [3]:
def create_dataloaders(train_datasets, test_datasets, val_datasets, batch_size=64):
    """Merge per-split dataset lists into one DataLoader per split.

    Args:
        train_datasets: Datasets to concatenate for training.
        test_datasets: Datasets to concatenate for testing.
        val_datasets: Datasets to concatenate for validation.
        batch_size: Batch size shared by all three loaders.

    Returns:
        A ``(train_loader, test_loader, val_loader)`` tuple. Only the
        training loader shuffles; the other two keep dataset order.
    """

    def _build(datasets, shuffle):
        # Concatenate first so that one loader spans every given dataset.
        merged = ConcatDataset(datasets)
        return DataLoader(merged, batch_size=batch_size, shuffle=shuffle)

    return (
        _build(train_datasets, True),
        _build(test_datasets, False),
        _build(val_datasets, False),
    )


def pipeline(image_dir: str, categories: list, device: torch.device, labels_map: dict,
             batch_size: int = 64):
    """Build combined train/test/validation loaders across all categories.

    Each category is loaded, processed, passed through ``augment_data``,
    converted to tensors and split. The per-category datasets of each split
    are then merged into a single loader per split.

    Args:
        image_dir: Directory passed to ``load_data`` for every category.
        categories: Category names to process, one ``load_data`` call each.
        device: Device handed to ``to_tensors``.
        labels_map: Mapping handed to ``to_tensors`` together with the label.
        batch_size: Batch size used for the per-category split and for the
            final loaders. Defaults to 64, the value that was previously
            hard-coded in both places.

    Returns:
        ``(train_loader, test_loader, val_loader)`` as produced by
        ``create_dataloaders``.

    Raises:
        ValueError: If ``categories`` yields no items, so there is nothing to
            build loaders from.
    """
    train_datasets, test_datasets, val_datasets = [], [], []

    for cat in tqdm(categories, desc="Processing categories"):
        # Load, process, augment, and convert the data for this category.
        features, label = load_data(image_dir, cat, file_standardize=False)
        features, label = process_data(features, label)
        # Every augmentation option is passed as zero/False here.
        features, label = augment_data(features, label, rot=0, h_flip=False, v_flip=False)
        features, labels = to_tensors(features, label, labels_map, device=device)

        # split_batch returns loaders; only their underlying datasets are
        # kept, because the final loaders are rebuilt over all categories.
        cat_train, cat_test, cat_val = split_batch(features, labels, batch_size=batch_size)
        train_datasets.append(cat_train.dataset)
        test_datasets.append(cat_test.dataset)
        val_datasets.append(cat_val.dataset)

    # Fail with a clear message instead of letting the concatenation step
    # reject an empty dataset list further down.
    if not train_datasets:
        raise ValueError("pipeline() received no categories; nothing to load")

    return create_dataloaders(
        train_datasets, test_datasets, val_datasets, batch_size=batch_size
    )


In [4]:
# Run the data pipeline once to get the three loaders used by the cells below.
train_loader, test_loader, val_loader = pipeline(
    image_dir=image_dir,
    categories=categories,
    device=device,
    labels_map=labels_map,
)

Processing categories: 100%|██████████| 2/2 [00:00<00:00,  7.69it/s]


In [5]:
# Instantiate the network and move it to the selected device, then pair it
# with a cross-entropy loss and an Adam optimizer over its parameters.
model = QuickDrawCNN_V2()
model = model.to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [6]:
# Train for up to 10 epochs. A manual interrupt (Ctrl-C / "Interrupt kernel")
# is caught so that the weights learned so far are written to disk instead of
# being lost.
try:
    train_loss, val_loss, train_acc, val_acc = train(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        epochs=10,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
    )
except KeyboardInterrupt:
    print("Training interrupted manually. Saving current model state...")
    torch.save(model.state_dict(), "model_state.pth")
    print("Model state saved.")
    # train() never returned, so this run produced no history. Bind the names
    # explicitly: otherwise later cells either fail with NameError or, worse,
    # silently pick up values left in the kernel by an earlier run and save
    # them next to the newly trained model.
    train_loss = val_loss = train_acc = val_acc = None


Epoch 1/10


Training: 100%|██████████| 3121/3121 [00:10<00:00, 301.15batch/s, acc=97.2, loss=0.0343]
Validating: 100%|██████████| 391/391 [00:00<00:00, 603.07batch/s, acc=100, loss=5.66e-5]



Epoch: 1 | Train Loss: 0.07184 | Train Acc: 97.47% | Val Loss: 0.04996 | Val Acc: 98.17%

Epoch 2/10


Training: 100%|██████████| 3121/3121 [00:10<00:00, 296.92batch/s, acc=88.9, loss=0.125]  
Validating: 100%|██████████| 391/391 [00:00<00:00, 614.43batch/s, acc=100, loss=2.68e-7] 



Epoch: 2 | Train Loss: 0.04642 | Train Acc: 98.29% | Val Loss: 0.05436 | Val Acc: 98.05%

Epoch 3/10


Training: 100%|██████████| 3121/3121 [00:10<00:00, 289.41batch/s, acc=97.2, loss=0.103]  
Validating: 100%|██████████| 391/391 [00:00<00:00, 610.72batch/s, acc=100, loss=1.16e-6] 


Epoch: 3 | Train Loss: 0.03969 | Train Acc: 98.51% | Val Loss: 0.05468 | Val Acc: 98.00%
Early stopping at epoch 3





In [7]:
# Everything produced by this notebook is written under `saves`.
parent_path = 'saves'
model_path = f'{parent_path}/model'
var_path = f'{parent_path}/vars'

# incremental_save returns the target that is opened below. The label map and
# the four training-history objects are pickled into that single file one
# after another, so reading them back takes one pickle.load call per object,
# in this same order. (The misspelled name is kept as-is in case other cells
# refer to it.)
varaibles_saved = incremental_save(var_path)
with open(varaibles_saved, "wb") as f:
    for obj in (labels_map, train_loss, val_loss, train_acc, val_acc):
        pickle.dump(obj, f)

# Hand the model itself to incremental_save under the model prefix.
model_saved = incremental_save(model_path, data=model)

Path: saves/vars.0.pth
Saved to saves/model.0.pth.
