In [59]:
import numpy as np
from tensorflow.keras.preprocessing.sequence import pad_sequences

In [60]:
# Load SketchRNN data from .npz files
def load_data(npz_file):
    data = np.load(npz_file, allow_pickle=True, encoding='latin1')
    num_row = int(len(data['train']) * 0.1)
    return data['train'][:num_row]

In [61]:
# Load datasets
mountain_data = load_data('sketchrnn/mountain.full.npz')
house_data = load_data('sketchrnn/house.full.npz')
cake_data = load_data('sketchrnn/cake.full.npz')
strawberry_data = load_data('sketchrnn/strawberry.full.npz')
candle_data = load_data('sketchrnn/candle.full.npz')
necklace_data = load_data('sketchrnn/necklace.full.npz')
rain_data = load_data('sketchrnn/rain.full.npz')
umbrella_data = load_data('sketchrnn/umbrella.full.npz')
vase_data = load_data('sketchrnn/vase.full.npz')
hat_data = load_data('sketchrnn/hat.full.npz')
lollipop_data = load_data('sketchrnn/lollipop.full.npz')
eye_data = load_data('sketchrnn/eye.full.npz')
fish_data = load_data('sketchrnn/fish.full.npz')
flower_data = load_data('sketchrnn/flower.full.npz')      

In [62]:
# Combine datasets
combined_data = np.concatenate([mountain_data, house_data, cake_data, strawberry_data, candle_data, 
necklace_data, rain_data, umbrella_data, vase_data, hat_data, lollipop_data, eye_data, fish_data, flower_data])

In [63]:
print(combined_data.shape)
print(combined_data[9].shape)

(166898,)
(25, 3)


In [64]:
# Define maximum sequence length
max_seq_len = max(len(sketch) for sketch in combined_data)

In [65]:
max_seq_len

148

In [66]:
# Pad sequences to the same length and extract features
def preprocess_data(data, max_seq_len):
    sequences = []
    for sketch in data:
        sequence = []
        for i in range(len(sketch)):
            dx, dy, state = sketch[i]
            sequence.append([dx, dy, state])
        sequences.append(sequence)
    padded_sequences = pad_sequences(sequences, maxlen=max_seq_len, padding='post', dtype='float32')
    return padded_sequences

In [67]:
train_sequences = preprocess_data(combined_data, max_seq_len)

In [68]:
print(train_sequences.shape)
print(train_sequences[9].shape)
print(train_sequences[9])

(166898, 148, 3)
(148, 3)
[[ 195.   15.    0.]
 [ 223.    2.    1.]
 [-360.  -25.    0.]
 [  30.  -97.    0.]
 [  23.  -57.    0.]
 [  87. -183.    0.]
 [   7.   -9.    0.]
 [   4.    0.    0.]
 [  13.   28.    0.]
 [  16.   61.    0.]
 [  35.  154.    0.]
 [  17.  101.    0.]
 [   9.   31.    1.]
 [-162. -281.    0.]
 [   8.   37.    0.]
 [   7.    2.    0.]
 [  18.  -23.    0.]
 [   4.    0.    0.]
 [  16.   18.    0.]
 [   5.    1.    0.]
 [  16.   -5.    0.]
 [  21.    6.    0.]
 [  15.   -6.    0.]
 [   8.   -5.    0.]
 [   2.   -5.    1.]
 [   0.    0.    0.]
 [   0.    0.    0.]
 [   0.    0.    0.]
 [   0.    0.    0.]
 [   0.    0.    0.]
 [   0.    0.    0.]
 [   0.    0.    0.]
 [   0.    0.    0.]
 [   0.    0.    0.]
 [   0.    0.    0.]
 [   0.    0.    0.]
 [   0.    0.    0.]
 [   0.    0.    0.]
 [   0.    0.    0.]
 [   0.    0.    0.]
 [   0.    0.    0.]
 [   0.    0.    0.]
 [   0.    0.    0.]
 [   0.    0.    0.]
 [   0.    0.    0.]
 [   0.    0.    0.]
 [   0. 

## Augment

In [69]:
import tensorflow as tf

def augment_data(sequences):
    # Example of augmenting sequences by adding Gaussian noise
    noise_factor = 0.05
    augmented_sequences = sequences + noise_factor * np.random.randn(*sequences.shape)
    return augmented_sequences

In [70]:
# Apply data augmentation
augmented_sequences = augment_data(train_sequences)

In [71]:
print(augmented_sequences.shape)
print(augmented_sequences[9].shape)
print(augmented_sequences[9])

(166898, 148, 3)
(148, 3)
[[ 1.94928346e+02  1.50122256e+01 -1.25408160e-02]
 [ 2.22956016e+02  2.03471291e+00  9.77923943e-01]
 [-3.59967019e+02 -2.50050592e+01 -2.41517317e-02]
 [ 2.99782552e+01 -9.70082217e+01 -3.85676491e-03]
 [ 2.29423195e+01 -5.69528732e+01 -4.55904536e-02]
 [ 8.69898662e+01 -1.82985327e+02  1.62803400e-02]
 [ 7.06917458e+00 -9.02912892e+00 -2.68939936e-02]
 [ 3.98899286e+00  1.55364706e-03  7.42428608e-02]
 [ 1.30268178e+01  2.80398315e+01  5.59214962e-02]
 [ 1.60720190e+01  6.09441784e+01  6.59994863e-03]
 [ 3.50678108e+01  1.53934570e+02  2.27696011e-02]
 [ 1.70419621e+01  1.00975475e+02 -3.40000991e-02]
 [ 9.01679284e+00  3.10108855e+01  1.05328960e+00]
 [-1.61985206e+02 -2.81010169e+02 -3.42991604e-02]
 [ 7.97587803e+00  3.69824942e+01  6.23667183e-02]
 [ 7.03657467e+00  2.01089171e+00  1.05694711e-02]
 [ 1.80520952e+01 -2.30603799e+01 -4.18957151e-02]
 [ 4.00248617e+00  6.30074462e-02  1.05923627e-02]
 [ 1.59323312e+01  1.80492039e+01  3.34389243e-02]
 [ 5.

## Modelling

In [72]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv1D, MaxPooling1D, Flatten, LSTM, Dense, TimeDistributed

In [73]:
# Define the model architecture
model = Sequential([
    TimeDistributed(Conv1D(filters=64, kernel_size=3, activation='relu'), input_shape=(max_seq_len, 3, 1)),
    TimeDistributed(MaxPooling1D(pool_size=1)),
    TimeDistributed(Flatten()),
    LSTM(256, return_sequences=True),
    LSTM(256),
    Dense(14, activation='softmax')  
])

# Compile the model
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.summary()

## Train

In [74]:
# Reshape data to add a channel dimension for CNN
train_sequences = np.expand_dims(train_sequences, axis=-1)
augmented_sequences = np.expand_dims(augmented_sequences, axis=-1)

In [75]:
# Assuming you have labels for the combined data
combined_labels = np.concatenate([np.zeros(len(mountain_data)), np.ones(len(house_data)), np.full(len(cake_data), 2),
np.full(len(strawberry_data), 3), np.full(len(candle_data), 4), np.full(len(necklace_data), 5), np.full(len(rain_data), 6),
np.full(len(umbrella_data), 7), np.full(len(vase_data), 8), np.full(len(hat_data), 2), np.full(len(lollipop_data), 9), 
np.full(len(eye_data), 10), np.full(len(fish_data), 11), np.full(len(flower_data), 12)])

In [76]:
# Train the model
model.fit(augmented_sequences, combined_labels, epochs=10, batch_size=64, validation_split=0.2)

Epoch 1/10
[1m2087/2087[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2497s[0m 1s/step - accuracy: 0.4510 - loss: 1.6164 - val_accuracy: 0.1190 - val_loss: 9.3060
Epoch 2/10
[1m2087/2087[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2986s[0m 1s/step - accuracy: 0.8947 - loss: 0.3332 - val_accuracy: 0.2077 - val_loss: 8.5599
Epoch 3/10
[1m2087/2087[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2174s[0m 1s/step - accuracy: 0.9326 - loss: 0.2200 - val_accuracy: 0.2170 - val_loss: 8.8300
Epoch 4/10
[1m2087/2087[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1854s[0m 888ms/step - accuracy: 0.9487 - loss: 0.1686 - val_accuracy: 0.2082 - val_loss: 9.1623
Epoch 5/10
[1m2087/2087[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1744s[0m 836ms/step - accuracy: 0.9601 - loss: 0.1335 - val_accuracy: 0.2315 - val_loss: 9.2300
Epoch 6/10
[1m 820/2087[0m [32m━━━━━━━[0m[37m━━━━━━━━━━━━━[0m [1m19:39[0m 931ms/step - accuracy: 0.9659 - loss: 0.1103

KeyboardInterrupt: 