In [1]:
import torch
import os
import numpy as np
import pickle
from tqdm import tqdm   
from torch.utils.data import DataLoader, ConcatDataset

from utils.model_train import train
from utils.processing import load_data, process_data, augment_data, to_tensors, split_batch, incremental_save, partition_datasets
from utils.quickdraw_cnn import QuickDrawCNN_V1, QuickDrawCNN_V2

In [2]:
image_dir = "data/numpy_bitmap/"
datasets_dir = "data/datasets"

# Sort the listing: os.listdir order is arbitrary, and the category chunking below
# (and any resumed run) needs the same order in every kernel session.
categories = sorted(os.listdir(image_dir))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stable category -> class-index mapping. Iterating a raw `set` of strings is not
# reproducible across interpreter runs (string hash randomization), which would give
# the same category different labels in datasets built in different sessions.
labels_map = {label: i for i, label in enumerate(sorted(set(categories)))}

# Seed the global RNGs (used by the subsampling mask and model init) for reproducibility.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
def pipeline(image_dir: str, categories: list, device: torch.device, labels_map: dict,
             keep_fraction: float = 0.4, batch_size: int = 32,
             rng: "np.random.Generator | None" = None):
    '''
    Load, process, subsample, and split the data, one category at a time.

    Args:
        image_dir: directory containing the per-category bitmap files.
        categories: category entries (as listed from `image_dir`) to process.
        device: torch device passed to `to_tensors`.
        labels_map: category -> integer class index mapping, passed to `to_tensors`.
        keep_fraction: probability that each sample is kept (default 0.4, i.e. roughly
            60% of the samples are dropped to limit memory use).
        batch_size: batch size passed to `split_batch`.
        rng: optional NumPy Generator for the subsampling mask. When None, the global
            NumPy random state is used (seed it for reproducible runs).

    Returns:
        Three lists (train, test, val), each holding one dataset per category that was
        processed successfully. Categories raising RuntimeError are reported and skipped.

    Raises:
        ValueError: if `keep_fraction` is outside [0, 1].
    '''
    if not 0.0 <= keep_fraction <= 1.0:
        raise ValueError(f"keep_fraction must be in [0, 1], got {keep_fraction}")

    train_datasets, test_datasets, val_datasets = [], [], []

    for cat in tqdm(categories, desc="Processing categories"):
        try:
            # Load and preprocess this category's samples.
            features, label = load_data(image_dir, cat, file_standardize=False)
            features, label = process_data(features, label)

            # Randomly keep ~keep_fraction of the samples. Only `features` is masked:
            # this assumes `label` is a single per-category label, not a per-sample
            # array — TODO confirm against utils.processing.
            draws = rng.random(len(features)) if rng is not None else np.random.rand(len(features))
            features = features[draws <= keep_fraction]

            features, labels = to_tensors(features, label, labels_map, device=device)

            # Split into train/test/val; only the underlying datasets are kept so they
            # can later be merged across categories.
            train_loader, test_loader, val_loader = split_batch(features, labels, batch_size=batch_size)
            train_datasets.append(train_loader.dataset)
            test_datasets.append(test_loader.dataset)
            val_datasets.append(val_loader.dataset)
        except RuntimeError as e:
            print(f"Error on {cat}: {e}")

    return train_datasets, test_datasets, val_datasets

In [None]:
# Number of pickle shard files written per split (train / test / val).
num_files = 15

# Shard paths follow "<datasets_dir>/<split>/<split>_datasets_<i>.pkl", one list per split.
train_file_paths, test_file_paths, val_file_paths = (
    [f"{datasets_dir}/{split}/{split}_datasets_{i}.pkl" for i in range(num_files)]
    for split in ("train", "test", "val")
)

In [10]:
# Process categories in `num_files` contiguous chunks, so only one chunk's tensors are
# in memory at a time, and write chunk k to shard file k of every split.
# Each chunk gets its own shard files; rewriting all shards on every chunk would keep
# only the last chunk's categories on disk.
# Chunks whose three shard files already exist are skipped, so an interrupted run can
# simply be re-executed. Delete stale shards (e.g. from older runs) to rebuild them.
chunk_bounds = [len(categories) * k // num_files for k in range(num_files + 1)]

for file_idx in tqdm(range(num_files), desc="Processing categorical split"):
    start, stop = chunk_bounds[file_idx], chunk_bounds[file_idx + 1]
    shard_paths = (train_file_paths[file_idx], test_file_paths[file_idx], val_file_paths[file_idx])

    if all(os.path.exists(path) for path in shard_paths):
        print(f"Skipping categories {start} to {stop}: shards already saved.")
        continue

    print(f"Processing categories {start} to {stop}...")
    try:
        # (train_datasets, test_datasets, val_datasets) for this chunk's categories
        split_datasets = pipeline(image_dir, categories[start:stop], device, labels_map)
    except RuntimeError as e:
        print(f"Error on {start}: {e}")
        torch.cuda.empty_cache()
        continue
    torch.cuda.empty_cache()

    # NOTE(review): the datasets presumably hold tensors created on `device`; if that is
    # a GPU, these pickles may only load on a CUDA machine — confirm / move to CPU first.
    for path, datasets in zip(shard_paths, split_datasets):
        os.makedirs(os.path.dirname(path), exist_ok=True)
        tmp_path = f"{path}.tmp"
        with open(tmp_path, "wb") as f:
            pickle.dump(datasets, f)
        # Rename only after a complete write, so an interruption never leaves a
        # truncated shard that the skip check above would treat as finished.
        os.replace(tmp_path, path)

    # Release this chunk before the next one is loaded.
    del split_datasets
    torch.cuda.empty_cache()

In [None]:
# Load the pickled dataset shards back from disk. Each shard file holds a list of
# per-category datasets; all shards of a split are concatenated into one list.
# NOTE: pickle.load can execute arbitrary code — only load shards this notebook wrote.
def load_split(split: str) -> list:
    """Return the concatenated dataset lists from every `.pkl` shard of `split`."""
    split_dir = f"{datasets_dir}/{split}"
    # Sorted for a deterministic order; the `.pkl` filter ignores leftover temp files.
    shard_files = sorted(name for name in os.listdir(split_dir) if name.endswith(".pkl"))
    datasets = []
    for name in tqdm(shard_files, desc=f"Loading {split} datasets"):
        with open(f"{split_dir}/{name}", "rb") as f:
            datasets.extend(pickle.load(f))
    return datasets


train_datasets = load_split("train")
test_datasets = load_split("test")
# The validation split is needed by the DataLoader cell and must be loaded too.
val_datasets = load_split("val")

In [None]:
# Merge each split's per-category datasets into a single dataset and wrap it in a
# DataLoader; only the training split is shuffled.
train_loader, test_loader, val_loader = (
    DataLoader(ConcatDataset(parts), batch_size=64, shuffle=is_train)
    for parts, is_train in (
        (train_datasets, True),
        (test_datasets, False),
        (val_datasets, False),
    )
)

In [57]:
# CNN classifier on the selected device, trained with cross-entropy loss and Adam (lr = 1e-3).
model = QuickDrawCNN_V2().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()

In [None]:
# Train the model. On a manual interrupt (KeyboardInterrupt), the current weights are
# saved to "model_state.pth" so the run is not lost.
# Reset the history first: otherwise an interrupted run would leave curves from an
# earlier run in the kernel, and the save cell would silently persist those stale values.
train_loss = val_loss = train_acc = val_acc = None
try:
    train_loss, val_loss, train_acc, val_acc = train(model=model,
                                                     train_loader=train_loader,
                                                     val_loader=val_loader,
                                                     epochs=10,
                                                     criterion=criterion,
                                                     optimizer=optimizer,
                                                     device=device)
except KeyboardInterrupt:
    print("Training interrupted manually. Saving current model state...")
    torch.save(model.state_dict(), "model_state.pth")
    print("Model state saved.")

In [7]:
parent_path = 'saves'
model_path = f'{parent_path}/model'
var_path = f'{parent_path}/vars'

# Gather everything to persist before calling incremental_save / opening the file.
# If one of these names is undefined (e.g. training never returned), the NameError is
# raised here instead of after a truncated vars file has been written.
saved_vars = (labels_map, train_loss, val_loss, train_acc, val_acc)

# The objects are pickled back-to-back into one file; read them back with repeated
# pickle.load calls in this same order.
varaibles_saved = incremental_save(var_path)
with open(varaibles_saved, "wb") as f:
    for obj in saved_vars:
        pickle.dump(obj, f)

model_saved = incremental_save(model_path, data=model)

Path: saves/vars.0.pth
Saved to saves/model.0.pth.


---