In [1]:
from diffusion_libs import get_network, DiffusionModel, scale_dataset
from samples_generators import fill_vocabulary_c_v2, convert_back_to_code_c_v2, vocabulary_c_v2
import tensorflow as tf
from tensorflow import keras
import numpy as np
import tensorflow as tf

In [2]:
# Notebook configuration: every tunable value and path used by the cells below.

# Must run before DICTIONARY_SIZE is computed, since that reads len(vocabulary_c_v2).
# NOTE(review): presumably this call populates vocabulary_c_v2 — confirm in samples_generators.
fill_vocabulary_c_v2()

# sampling
min_signal_rate = 0.02
max_signal_rate = 0.95

# architecture
embedding_dims = 32
embedding_max_frequency = 1000.0
embedding_min_frequency = 1.0
widths = [32, 64, 64, 96, 128]
block_depth = 2

# optimization
batch_size = 16
ema = 0.999
learning_rate = 1e-3

# dictionary related
DICTIONARY_SIZE = len(vocabulary_c_v2)
TOKENS_CAPACITY = 256

# paths (all relative to the project root, which must be the working directory)
data_dir = "./data/simple_c_v2/"
lang_base = "checkpoints/simple_c_v2"
# Derived from lang_base instead of a hard-coded absolute Windows path: the old
# literal only worked on one machine and contained invalid escape sequences
# ("\S", "\m", "\c", "\s") because it was not a raw string.
model_path = f"{lang_base}/cp-0128/model"

In [3]:
# Build the network, wrap it in a DiffusionModel, compile it, attach the saved
# normalizer statistics, and finally restore the trained weights from model_path.
# Positional arguments are kept in exactly the order the project API receives them.
network = get_network(
    TOKENS_CAPACITY,
    embedding_min_frequency,
    embedding_max_frequency,
    embedding_dims,
    widths=widths,
    block_depth=block_depth,
    name="complicated",
)

model = DiffusionModel(
    TOKENS_CAPACITY,
    DICTIONARY_SIZE,
    network,
    batch_size,
    max_signal_rate,
    min_signal_rate,
    ema,
    False,  # NOTE(review): meaning of this positional flag is not visible here — confirm against DiffusionModel
)

model.compile(
    optimizer=keras.optimizers.experimental.Adam(learning_rate=learning_rate),
    loss=keras.losses.mean_absolute_error,
)

# normalizer: index 0 of the saved array is used as the mean, index 1 as the variance.
# NOTE(review): allow_pickle=True can execute arbitrary code embedded in the file;
# only load normalizer_weights.npy from a trusted source.
n_w = np.load(f"{lang_base}/normalizer_weights.npy", allow_pickle=True)
normalizer = keras.layers.Normalization(mean=n_w[0], variance=n_w[1])
# build() expects a shape. "(TOKENS_CAPACITY)" is just an int in parentheses,
# so pass an explicit one-element tuple.
normalizer.build((TOKENS_CAPACITY,))
model.normalizer = normalizer

# Kept as the last expression so the returned load status is shown as the cell output.
model.load_weights(model_path)

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x18729195d60>

### First, generate a few samples to check that the model still produces something that looks like code

In [7]:
# Generate samples from the restored model and print each one as source text
# for a visual sanity check.
# NOTE(review): the meaning of the two generate() arguments (5, 10) is not
# visible in this notebook — confirm against DiffusionModel.generate.
raw, denormalized = model.generate(5, 10)
for sample in denormalized:
    scaled = scale_dataset(sample, DICTIONARY_SIZE)
    tokens = convert_back_to_code_c_v2(scaled)
    code_text = " ".join(tokens)
    # Start a new line after every ';' and '{' so the output is readable.
    code_text = code_text.replace(";", ";\n").replace("{", "{\n")
    # One blank line between consecutive samples.
    print(code_text, end="\n\n")

[ 1.0161093 27.000017  14.9999695  4.392937  28.653275  19.244692
 10.001625  28.56835   20.962751  13.507574  18.209211  23.836937
 14.256599  16.679403  19.672861  15.599954  15.915151  18.190376
 18.130323  17.580717  16.289984  18.566893  17.63468   18.292414
 21.670856  18.056206  19.100237  18.2239    18.786673  17.73656
 17.5303    21.24002   18.504879  17.884611  18.400661  19.111486
 18.737213  16.357046  19.407345  18.712608  22.107601  18.875547
 22.66751   18.382893  21.263096  17.832771  16.88466   19.154789
 21.651085  18.707788  21.226793  19.836119  18.09608   19.241913
 19.455513  19.22755   21.325056  20.172117  18.888924  19.585308
 19.414227  22.768976  19.64555   19.196428  19.565231  17.828604
 19.886705  21.268661  19.448553  23.578424  20.134567  17.656477
 20.319443  19.345291  23.832159  19.88308   23.749313  19.447647
 23.47402   18.710468  23.928694  18.843187  17.880508  19.882515
 17.737265  21.250614  18.63592   24.528435  20.303911  24.257282
 19.488567 

In [5]:
# Show which vocabulary entry is stored under index 26.
token_id = 26
print(vocabulary_c_v2[token_id])

EMPTY
