In [2]:
import torch
import os
import numpy as np
import pickle
from tqdm import tqdm   
from torch.utils.data import DataLoader, ConcatDataset

from utils.model_train import train
from utils.processing import load_data, process_data, augment_data, to_tensors, split_batch, incremental_save, partition_datasets
from utils.quickdraw_cnn import QuickDrawCNN_V1, QuickDrawCNN_V2

In [6]:
image_dir = "data/numpy_bitmap/"
datasets_dir = "data/datasets"
categories = os.listdir(image_dir)
categories = categories[:100]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
labels_map = {label: i for i, label in enumerate(set(categories))}
reversed_labels_map = {v: k for k, v in labels_map.items()}

In [7]:
def pipeline(image_dir: str, categories: list, device: torch.device, labels_map: dict):
    
    '''
    A pipeline function to run the entire process of loading, processing, augmenting, and splitting the data onto each category at a time.
    '''
    
    train_datasets, test_datasets, val_datasets = [], [], []

    for cat in tqdm(categories, desc="Processing categories"):
        
        try:
            # load, process, augment, and split the data
            features, label = load_data(image_dir, cat, file_standardize=False)
            features, label = process_data(features, label)
            mask = np.random.rand(len(features)) <= 0.003 # keep 30% of the data
            features = features[mask]
            # features, label = augment_data(features, label, rot=0, h_flip=False, v_flip=False)
            features, labels = to_tensors(features, label, labels_map, device=device)
            
            # split the data into train, test, and validation sets
            train_loader, test_loader, val_loader = split_batch(features, labels, batch_size=32)
            train_datasets.append(train_loader)
            test_datasets.append(test_loader)
            val_datasets.append(val_loader)
        except RuntimeError as e:
            print(f"Error on {cat}: {e}")
        
    return train_datasets, test_datasets, val_datasets

In [8]:
x, y, z = pipeline(image_dir, categories, device, labels_map)


Processing categories: 100%|██████████| 100/100 [02:59<00:00,  1.79s/it]


In [9]:
for i, (train, test, val) in enumerate(zip(x, y, z)):
    model = QuickDrawCNN_V2(num_classes=len(labels_map))
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = torch.nn.CrossEntropyLoss()
    train(model, optimizer, criterion, train, val, device, epochs=10, save_path=f"models/model_{i}.pt")
    print(f"Model {i} trained successfully.")
    del model
    torch.cuda.empty_cache()

[<torch.utils.data.dataset.TensorDataset at 0x1a81e8eae70>,
 <torch.utils.data.dataset.TensorDataset at 0x1a81e5dbcb0>,
 <torch.utils.data.dataset.TensorDataset at 0x1a81e833620>,
 <torch.utils.data.dataset.TensorDataset at 0x1a81e8334d0>,
 <torch.utils.data.dataset.TensorDataset at 0x1a81e78e420>,
 <torch.utils.data.dataset.TensorDataset at 0x1a836cf7b00>,
 <torch.utils.data.dataset.TensorDataset at 0x1a836cf75c0>,
 <torch.utils.data.dataset.TensorDataset at 0x1a836cf6750>,
 <torch.utils.data.dataset.TensorDataset at 0x1a81e8645c0>,
 <torch.utils.data.dataset.TensorDataset at 0x1a836cf73e0>,
 <torch.utils.data.dataset.TensorDataset at 0x1a81e669a30>,
 <torch.utils.data.dataset.TensorDataset at 0x1a836d3cd40>,
 <torch.utils.data.dataset.TensorDataset at 0x1a836d3d2e0>,
 <torch.utils.data.dataset.TensorDataset at 0x1a81e865130>,
 <torch.utils.data.dataset.TensorDataset at 0x1a836d3d280>,
 <torch.utils.data.dataset.TensorDataset at 0x1a836d3d550>,
 <torch.utils.data.dataset.TensorDataset

In [9]:
num_files = 15

train_file_paths = [f"{datasets_dir}/train/train_datasets_{i}.pkl" for i in range(num_files)]
test_file_paths = [f"{datasets_dir}/test/test_datasets_{i}.pkl" for i in range(num_files)]
val_file_paths = [f"{datasets_dir}/val/val_datasets_{i}.pkl" for i in range(num_files)]

In [29]:
# Round robin system to distribute the categories into the files
for i in tqdm(range(0, len(categories), len(categories) // num_files), desc="Processing categorical split"):
    print(f"Processing categories {i} to {min(i + len(categories) // num_files, len(categories))}...")
    
    train_buffers = [[] for _ in range(num_files)]
    test_buffers = [[] for _ in range(num_files)]
    val_buffers = [[] for _ in range(num_files)]
        
    try:
        train_datasets, test_datasets, val_datasets = pipeline(image_dir, categories[i:i + len(categories) // num_files], device, labels_map)
        print(f'processing category {categories[i:i + len(categories) // num_files]}')
        torch.cuda.empty_cache()

        # Distribute datasets into buffers
        for idx, (train_data, test_data, val_data) in enumerate(tqdm(zip(train_datasets, test_datasets, val_datasets), desc="Distributing datasets", total=len(train_datasets))):
            file_idx = idx % num_files
            train_buffers[file_idx].append(train_data)
            test_buffers[file_idx].append(test_data)
            val_buffers[file_idx].append(val_data)

        # Save the buffers
        for file_idx in tqdm(range(num_files), desc="Saving buffers"):
            with open(train_file_paths[file_idx], "wb") as f:
                pickle.dump(train_buffers[file_idx], f)
            with open(test_file_paths[file_idx], "wb") as f:
                pickle.dump(test_buffers[file_idx], f)
            with open(val_file_paths[file_idx], "wb") as f:
                pickle.dump(val_buffers[file_idx], f)

        # Clear buffers to free up cuda memory
        del train_datasets, test_datasets, val_datasets
        del train_data, test_data, val_data
        del train_buffers, test_buffers, val_buffers
        
        
        torch.cuda.empty_cache()

    except RuntimeError as e:
        print(f"Error on {i}: {e}")

Processing categorical split:   0%|          | 0/15 [00:00<?, ?it/s]

Processing categories 0 to 23...


Processing categories:  96%|█████████▌| 22/23 [01:00<00:02,  2.73s/it]
Processing categorical split:   0%|          | 0/15 [01:00<?, ?it/s]


KeyboardInterrupt: 

---

In [46]:
model = QuickDrawCNN_V2().to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [27]:
with open(val_file_paths[14], "rb") as f:
    train_datasets = pickle.load(f)

category_counts = {label: 0 for label in labels_map}
for dataset in train_datasets:
    for _, label in dataset:
        category_counts[reversed_labels_map[label.item()]] += 1
        
non_zero_counts = {k: v for k, v in category_counts.items() if v != 0}
print(non_zero_counts)

{'trumpet.npy': 511}


In [49]:
# Iterate trough filed training the model on each before deleting variables and moving to the next
for i in range(num_files):

    # load train and validation datasets
    with open(train_file_paths[i], "rb") as f:
        train_datasets = pickle.load(f)
    with open(val_file_paths[i], "rb") as f:
        val_datasets = pickle.load(f)
        
    # convert datasets to dataloaders
    train_loader = DataLoader(ConcatDataset(train_datasets), batch_size=32, shuffle=True)
    val_loader = DataLoader(ConcatDataset(val_datasets), batch_size=32, shuffle=True)
        
    # train the model
    try:
        train_loss, val_loss, train_acc, val_acc = train(model=model,
                                                        train_loader=train_loader,
                                                        val_loader=val_loader,
                                                        epochs=10,
                                                        criterion=criterion,
                                                        optimizer=optimizer,
                                                        device=device)
    except KeyboardInterrupt:
        print("Training interrupted manually. Saving current model state...")
        torch.save(model.state_dict(), "model_state.pth")
        print("Model state saved.")

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14


In [7]:
parent_path = 'saves'
model_path = f'{parent_path}/model'
var_path = f'{parent_path}/vars'

varaibles_saved = incremental_save(var_path)
with open(varaibles_saved, "wb") as f:
    pickle.dump(labels_map, f)
    pickle.dump(train_loss, f)
    pickle.dump(val_loss, f)
    pickle.dump(train_acc, f)
    pickle.dump(val_acc, f)

model_saved = incremental_save(model_path, data=model)

Path: saves/vars.0.pth
Saved to saves/model.0.pth.


---