In [7]:
import os
os.environ["KERAS_BACKEND"] = "jax"

In [8]:
import tensorflow as tf
import numpy as np
import json
from sklearn.model_selection import train_test_split
from tqdm import tqdm

def data_generator(filepath, tokenizer, batch_size=32, max_length=100, test_size=0.2):
    """
    A generator function that yields batches of tokenized and padded sequences and their labels.
    
    Parameters:
    - filepath: Path to the JSONL file.
    - tokenizer: An instance of tf.keras.preprocessing.text.Tokenizer.
    - batch_size: The number of samples to return in each batch.
    - max_length: The maximum length of the sequences after padding.
    - test_size: The proportion of the dataset to include in the test split.
    
    Yields:
    - A tuple (batch_sequences, batch_labels), where:
        - batch_sequences is a numpy array of tokenized and padded sequences.
        - batch_labels is a numpy array of labels for each sequence in the batch.
    """
    titles = []
    view_counts = []

    with open(filepath, 'r', encoding='utf-8') as file:
        for line in tqdm(file, desc="Loading and processing data"):
            record = json.loads(line)
            titles.append(record['title'])
            view_counts.append(record['view_count'])
            
            if len(titles) == batch_size:
                sequences = tokenizer.texts_to_sequences(titles)
                padded_sequences = tf.keras.preprocessing.sequence.pad_sequences(sequences, maxlen=max_length, padding='post', truncating='post')
                labels = np.log(np.array(view_counts, dtype=np.float32))
                labels = np.where(labels == -np.inf, 0, labels)
                
                yield padded_sequences, labels
                
                titles = []
                view_counts = []
                
    if titles:
        sequences = tokenizer.texts_to_sequences(titles)
        padded_sequences = tf.keras.preprocessing.sequence.pad_sequences(sequences, maxlen=max_length, padding='post', truncating='post')
        labels = np.log(np.array(view_counts, dtype=np.float32))
        labels = np.where(labels == -np.inf, 0, labels)
        
        yield padded_sequences, labels

def sample_titles(filepath, sample_size=1000):
    """
    Reads a sample of titles from a JSONL file.

    Parameters:
    - filepath: Path to the JSONL file.
    - sample_size: Number of titles to sample.
    
    Returns:
    - A list of sampled titles.
    """
    titles = []
    with open(filepath, 'r', encoding='utf-8') as file:
        for line in file:
            if len(titles) >= sample_size:
                break
            record = json.loads(line)
            titles.append(record['title'])
    return titles


max_length = 100
filepath = '/mnt/datassd/processed_file.jsonl'

titles_sample = sample_titles(filepath)


tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=10000, oov_token="<OOV>")

tokenizer.fit_on_texts(titles_sample)

#tokenizer.fit_on_texts(titles)

In [9]:
import mlflow
import keras
import keras_nlp
from keras import layers

vocab_size = 10000  # Adjust based on your vocabulary size
embedding_dim = 256
max_length = 100  # Adjust based on your titles' maximum length
num_heads = 8  # Number of attention heads in the Transformer encoder
intermediate_dim = 512  # Dimensionality of the encoder's intermediate (feed-forward) layer

# Define input layer
inputs = keras.Input(shape=(max_length,), dtype='int64')

# Token and position embedding layer
embedding_layer = keras_nlp.layers.TokenAndPositionEmbedding(
    vocabulary_size=vocab_size,
    sequence_length=max_length,
    embedding_dim=embedding_dim,
)
x = embedding_layer(inputs)

# Transformer encoder layer
encoder = keras_nlp.layers.TransformerEncoder(
    num_heads=num_heads,
    intermediate_dim=intermediate_dim,
    activation='relu',
    dropout=0.1,
)
x = encoder(x)

# GlobalMaxPooling1D layer for regression task
x = layers.GlobalMaxPooling1D()(x)

# Additional dense layers
x = layers.Dense(256, activation='relu')(x)
outputs = layers.Dense(1, activation='linear')(x)  # Linear activation for regression

# Compile the model
model = keras.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-4), loss='mean_squared_error')
model.summary()

In [12]:
model = keras.saving.load_model('/home/jay/AI/YT_predict/2.0/YT_Transformer.keras')

predictions = model.predict(data_generator(filepath, tokenizer, batch_size=2))


  return jnp.asarray(x, dtype=dtype)


      1/Unknown [1m3s[0m 3s/step

Loading and processing data: 4it [00:02,  1.36it/s]

    250/Unknown [1m3s[0m 403us/step

Loading and processing data: 502it [00:03, 232.25it/s]

    504/Unknown [1m3s[0m 399us/step

Loading and processing data: 1010it [00:03, 533.41it/s]

    757/Unknown [1m3s[0m 398us/step

Loading and processing data: 1516it [00:03, 902.28it/s]

   1013/Unknown [1m3s[0m 396us/step

  labels = np.log(np.array(view_counts, dtype=np.float32))
Loading and processing data: 2028it [00:03, 1341.59it/s]

   1291/Unknown [1m3s[0m 389us/step

Loading and processing data: 2584it [00:03, 1887.93it/s]

   1564/Unknown [1m4s[0m 385us/step

Loading and processing data: 3130it [00:03, 2444.00it/s]

   1838/Unknown [1m4s[0m 382us/step

Loading and processing data: 3678it [00:03, 2994.32it/s]

   2111/Unknown [1m4s[0m 380us/step

Loading and processing data: 4224it [00:03, 3501.65it/s]

   2381/Unknown [1m4s[0m 380us/step

Loading and processing data: 4764it [00:03, 3928.01it/s]

   2660/Unknown [1m4s[0m 378us/step

Loading and processing data: 5322it [00:03, 4326.47it/s]

   2940/Unknown [1m4s[0m 376us/step

Loading and processing data: 5882it [00:04, 4653.82it/s]

   3220/Unknown [1m4s[0m 374us/step

Loading and processing data: 6442it [00:04, 4903.49it/s]

   3501/Unknown [1m4s[0m 373us/step

Loading and processing data: 7004it [00:04, 5099.03it/s]

   3780/Unknown [1m4s[0m 372us/step

Loading and processing data: 7562it [00:04, 5227.08it/s]

   4063/Unknown [1m4s[0m 371us/step

Loading and processing data: 8128it [00:04, 5345.63it/s]

   4306/Unknown [1m5s[0m 373us/step

Loading and processing data: 8685it [00:04, 5183.89it/s]

   4546/Unknown [1m5s[0m 376us/step

Loading and processing data: 9220it [00:04, 5056.78it/s]

   4786/Unknown [1m5s[0m 378us/step

Loading and processing data: 9738it [00:04, 4972.85it/s]

   5025/Unknown [1m5s[0m 380us/step

Loading and processing data: 10244it [00:04, 4910.21it/s]

   5266/Unknown [1m5s[0m 381us/step

Loading and processing data: 10741it [00:04, 4879.88it/s]

   5506/Unknown [1m5s[0m 383us/step

Loading and processing data: 11233it [00:05, 4851.43it/s]

   5745/Unknown [1m5s[0m 385us/step

Loading and processing data: 11721it [00:05, 4829.04it/s]

   5986/Unknown [1m5s[0m 386us/step

Loading and processing data: 12206it [00:05, 4799.24it/s]

   6339/Unknown [1m5s[0m 388us/step

Loading and processing data: 12688it [00:05, 4763.35it/s]

   6594/Unknown [1m5s[0m 388us/step

Loading and processing data: 13194it [00:05, 4848.76it/s]

   6860/Unknown [1m6s[0m 388us/step

Loading and processing data: 13724it [00:05, 4979.79it/s]

   7126/Unknown [1m6s[0m 388us/step

Loading and processing data: 14254it [00:05, 5072.59it/s]

   7392/Unknown [1m6s[0m 387us/step

Loading and processing data: 14786it [00:05, 5141.08it/s]

   7655/Unknown [1m6s[0m 387us/step

Loading and processing data: 15312it [00:05, 5174.03it/s]

   7911/Unknown [1m6s[0m 387us/step

Loading and processing data: 15830it [00:06, 5153.09it/s]

   8165/Unknown [1m6s[0m 387us/step

Loading and processing data: 16346it [00:06, 5128.26it/s]

   8429/Unknown [1m6s[0m 387us/step

Loading and processing data: 16872it [00:06, 5166.10it/s]

   8693/Unknown [1m6s[0m 387us/step

Loading and processing data: 17400it [00:06, 5197.38it/s]

   8947/Unknown [1m6s[0m 387us/step

Loading and processing data: 17920it [00:06, 5156.57it/s]

   9208/Unknown [1m7s[0m 387us/step

Loading and processing data: 18442it [00:06, 5173.11it/s]

   9467/Unknown [1m7s[0m 387us/step

Loading and processing data: 18960it [00:06, 5160.12it/s]

   9731/Unknown [1m7s[0m 387us/step

Loading and processing data: 19486it [00:06, 5188.87it/s]

   9986/Unknown [1m7s[0m 387us/step

Loading and processing data: 20005it [00:06, 5158.64it/s]

  10250/Unknown [1m7s[0m 387us/step

Loading and processing data: 20530it [00:06, 5185.06it/s]

  10518/Unknown [1m7s[0m 387us/step

Loading and processing data: 21064it [00:07, 5227.16it/s]

  10775/Unknown [1m7s[0m 387us/step

Loading and processing data: 21587it [00:07, 5196.65it/s]

  11032/Unknown [1m7s[0m 387us/step

Loading and processing data: 22107it [00:07, 5171.38it/s]

  11288/Unknown [1m7s[0m 387us/step

Loading and processing data: 22625it [00:07, 5154.73it/s]

  11548/Unknown [1m7s[0m 387us/step

Loading and processing data: 23144it [00:07, 5161.19it/s]

  11808/Unknown [1m8s[0m 387us/step

Loading and processing data: 23664it [00:07, 5170.02it/s]

  12074/Unknown [1m8s[0m 387us/step

Loading and processing data: 24194it [00:07, 5204.85it/s]

  12340/Unknown [1m8s[0m 387us/step

Loading and processing data: 24724it [00:07, 5232.04it/s]

  12606/Unknown [1m8s[0m 386us/step

Loading and processing data: 25255it [00:07, 5255.30it/s]

  12868/Unknown [1m8s[0m 386us/step

Loading and processing data: 25781it [00:07, 5238.10it/s]

  13124/Unknown [1m8s[0m 386us/step

Loading and processing data: 26305it [00:08, 5198.62it/s]

  13383/Unknown [1m8s[0m 386us/step

Loading and processing data: 26825it [00:08, 5193.90it/s]

  13647/Unknown [1m8s[0m 386us/step

Loading and processing data: 27352it [00:08, 5212.81it/s]

  13909/Unknown [1m8s[0m 386us/step

Loading and processing data: 27874it [00:08, 5211.64it/s]

  14171/Unknown [1m8s[0m 386us/step

Loading and processing data: 28396it [00:08, 5209.82it/s]

  14434/Unknown [1m9s[0m 386us/step

Loading and processing data: 28922it [00:08, 5222.70it/s]

  14698/Unknown [1m9s[0m 386us/step

Loading and processing data: 29450it [00:08, 5235.26it/s]

  14962/Unknown [1m9s[0m 386us/step

Loading and processing data: 29978it [00:08, 5243.25it/s]

  15223/Unknown [1m9s[0m 386us/step

Loading and processing data: 30503it [00:08, 5233.64it/s]

  15484/Unknown [1m9s[0m 386us/step

Loading and processing data: 31027it [00:08, 5217.72it/s]

  15740/Unknown [1m9s[0m 386us/step

Loading and processing data: 31549it [00:09, 5186.81it/s]

  16001/Unknown [1m9s[0m 386us/step

Loading and processing data: 32070it [00:09, 5188.24it/s]

  16264/Unknown [1m9s[0m 386us/step

Loading and processing data: 32592it [00:09, 5196.01it/s]

  16525/Unknown [1m9s[0m 386us/step

Loading and processing data: 33116it [00:09, 5204.70it/s]

  16787/Unknown [1m9s[0m 386us/step

Loading and processing data: 33640it [00:09, 5211.25it/s]

  17049/Unknown [1m10s[0m 386us/step

Loading and processing data: 34162it [00:09, 5207.22it/s]

  17314/Unknown [1m10s[0m 386us/step

Loading and processing data: 34692it [00:09, 5229.19it/s]

  17579/Unknown [1m10s[0m 386us/step

Loading and processing data: 35220it [00:09, 5243.57it/s]

  17839/Unknown [1m10s[0m 386us/step

Loading and processing data: 35745it [00:09, 5218.84it/s]

  18102/Unknown [1m10s[0m 386us/step

Loading and processing data: 36268it [00:09, 5220.39it/s]

  18361/Unknown [1m10s[0m 386us/step

Loading and processing data: 36791it [00:10, 5208.10it/s]

  18619/Unknown [1m10s[0m 386us/step

Loading and processing data: 37312it [00:10, 5185.06it/s]

  18879/Unknown [1m10s[0m 386us/step

Loading and processing data: 37831it [00:10, 5182.10it/s]

  19140/Unknown [1m10s[0m 386us/step

Loading and processing data: 38352it [00:10, 5187.90it/s]

  19401/Unknown [1m10s[0m 386us/step

Loading and processing data: 38872it [00:10, 5189.36it/s]

  19663/Unknown [1m11s[0m 386us/step

Loading and processing data: 39391it [00:10, 5096.93it/s]

  19868/Unknown [1m11s[0m 387us/step

Loading and processing data: 39902it [00:10, 4878.38it/s]

  20119/Unknown [1m11s[0m 387us/step

Loading and processing data: 40408it [00:10, 4928.44it/s]

  20368/Unknown [1m11s[0m 387us/step

Loading and processing data: 40904it [00:10, 4934.13it/s]

  20622/Unknown [1m11s[0m 387us/step

Loading and processing data: 41414it [00:10, 4981.49it/s]

  20868/Unknown [1m11s[0m 387us/step

Loading and processing data: 41914it [00:11, 4920.02it/s]

  21106/Unknown [1m11s[0m 388us/step

Loading and processing data: 42407it [00:11, 4866.57it/s]

  21344/Unknown [1m11s[0m 388us/step

Loading and processing data: 42895it [00:11, 4854.29it/s]

  21589/Unknown [1m11s[0m 388us/step

Loading and processing data: 43381it [00:11, 4835.68it/s]

  21832/Unknown [1m11s[0m 389us/step

Loading and processing data: 43874it [00:11, 4860.65it/s]

  22077/Unknown [1m12s[0m 389us/step

Loading and processing data: 44361it [00:11, 4840.66it/s]

  22312/Unknown [1m12s[0m 389us/step

Loading and processing data: 44846it [00:11, 4768.34it/s]

  22545/Unknown [1m12s[0m 390us/step

Loading and processing data: 45324it [00:11, 4769.60it/s]

  22794/Unknown [1m12s[0m 390us/step

Loading and processing data: 45844it [00:11, 4893.87it/s]

  23063/Unknown [1m12s[0m 390us/step

Loading and processing data: 46362it [00:11, 4975.25it/s]

  23309/Unknown [1m12s[0m 390us/step

Loading and processing data: 46860it [00:12, 4957.10it/s]

  23557/Unknown [1m12s[0m 390us/step

Loading and processing data: 47366it [00:12, 4985.54it/s]

  23817/Unknown [1m12s[0m 390us/step

Loading and processing data: 47865it [00:12, 4981.66it/s]

  24055/Unknown [1m12s[0m 390us/step

In [None]:
# Make a line 
x = np.linspace(0,10,100)
y = np.linspace(0,10,100)

In [None]:
#import seaborn as sns
#import matplotlib.pyplot as plt
#
#
#heatmap, xedges, yedges = np.histogram2d(y_test.flatten(), predictions.flatten(), bins=100)
#
#extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
#
#
#plt.imshow(heatmap.T, extent=extent, origin='lower')
#plt.plot(x,y, 'r--')
#plt.xlabel('Views Order of Magnitude')
#plt.ylabel('Predicted Order of Magnitude')
#plt.xlim(2,15)
#plt.ylim(0,17)
#plt.savefig(f'{actual_model_name}_heatmap_bonito.png')
#
#import matplotlib.pyplot as plt
#
#
#plt.scatter(y_test, predictions, alpha=0.1, s=0.5)
#plt.plot(x,y,'r--')
#plt.xlabel('Actual View Count')
#plt.ylabel('Predicted View Count')
#plt.savefig(f'{actual_model_name}_scatter_bonit.png')

In [None]:
# If you need to convert an array of values
y_test_e = np.exp(y_test)  # Assuming y_test was in loge form
y_test_10 = np.log10(y_test_e)

predictions_e = np.exp(predictions)  # Assuming predictions were in loge form
predictions_10 = np.log10(predictions_e)

y_test = y_test_10
predictions = predictions_10

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# Assuming y_test and predictions are available and in log form
# Heatmap
heatmap, xedges, yedges = np.histogram2d(y_test.flatten(), predictions.flatten(), bins=100)
extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]

plt.figure(figsize=(10, 8))
sns.set(style="white")

# Using a colormap (e.g., 'viridis' which is visually appealing and colorblind-friendly)
plt.imshow(heatmap.T, extent=extent, origin='lower', aspect='auto', cmap='viridis')

# Assuming x, y for the red dashed line are defined correctly and correspond to log scale
plt.plot(x, y, 'r--')

plt.xlabel('Log of Actual View Count')
plt.ylabel('Log of Predicted View Count')
plt.colorbar(label='Count of Test')
plt.title('Heatmap of Predictions vs Actual Views')
plt.xlim(0, 9)
plt.ylim(0, 9)

# Adjusting x and y axis to show in 10^ format
ax = plt.gca()
ax.set_xticklabels([f'$10^{{{int(float(label))}}}$' for label in ax.get_xticks()])
ax.set_yticklabels([f'$10^{{{int(float(label))}}}$' for label in ax.get_yticks()])

plt.savefig(f'{actual_model_name}_heatmap_bonito.png')

In [None]:
plt.figure(figsize=(10, 8))
sns.set(style="whitegrid")

# Scatter plot with adjustments for alpha and size for better visibility
plt.scatter(y_test, predictions, alpha=0.2, s=10, cmap='viridis')

plt.plot(x, y, 'r--')  # Assuming x, y for the red dashed line are correct

plt.xlabel('Log of Actual View Count')
plt.ylabel('Log of Predicted View Count')
plt.title('Scatter Plot of Predicted vs Actual Views')
plt.xlim(0, 9)
plt.ylim(0, 9)


# Adjust axis to reflect 10^x and 10^y
ax = plt.gca()
ax.set_xticklabels([f'$10^{{{int(float(label))}}}$' for label in ax.get_xticks()])
ax.set_yticklabels([f'$10^{{{int(float(label))}}}$' for label in ax.get_yticks()])

plt.savefig(f'{actual_model_name}_scatter_bonito.png')