In [1]:
import pytest
import json
from pathlib import Path

# Adjust the import path based on your project structure
# This assumes 'tests' is at the same level as 'models'
from models.llama.config import ModelConfig

@pytest.fixture
def small_model_config_dir(tmp_path: Path) -> str:
    """
    Pytest fixture that writes a config.json for a small model into a
    temporary directory.

    It is registered with ``@pytest.fixture`` so that a test can request it
    simply by declaring a parameter named ``small_model_config_dir``; without
    the decorator pytest cannot resolve that parameter.

    Args:
        tmp_path: Per-test temporary directory (pytest's built-in
            ``tmp_path`` fixture).

    Returns:
        The path of the directory that contains ``config.json``, as a string.
    """
    config_data = {
        "hidden_size": 64,
        "num_hidden_layers": 2,
        "num_attention_heads": 4,
        "num_key_value_heads": 2,
        "intermediate_size": 128,
        "vocab_size": 1000,
        "rms_norm_eps": 1e-6,
        "rope_theta": 1000.0,
        "max_position_embeddings": 512,
        "hidden_act": "silu"
    }

    config_path = tmp_path / "config.json"
    # Explicit encoding so the written file does not depend on the platform default.
    with open(config_path, "w", encoding="utf-8") as config_file:
        json.dump(config_data, config_file)

    # The directory (not the file) is returned; callers pass it on as-is.
    return str(tmp_path)

def test_model_config_loading(small_model_config_dir: str):
    """
    Loads a ModelConfig through ``ModelConfig.from_json_file`` and verifies
    every attribute of the resulting object against its expected value.
    """
    loaded = ModelConfig.from_json_file(small_model_config_dir)

    # Attribute on the loaded config -> value it must hold.
    # Checked in insertion order, one equality assertion per attribute.
    expected_fields = {
        "dim": 64,
        "n_layers": 2,
        "n_heads": 4,
        "n_kv_heads": 2,
        "ffn_hidden_dim": 128,
        "vocab_size": 1000,
        "rms_norm_eps": 1e-6,
        "rope_theta": 1000.0,
        "max_seq_len": 512,
        "activation_fn": "silu",
    }
    for attr_name, expected_value in expected_fields.items():
        assert getattr(loaded, attr_name) == expected_value

    # Derived value rather than a stored one: 64 / 4.
    assert loaded.head_dim == 16

def test_gqa_validation():
    """
    Verifies that constructing a ModelConfig whose n_heads is not a multiple
    of n_kv_heads raises a ValueError mentioning "must be divisible by".
    """
    # 5 is not divisible by 2, so this combination must be rejected.
    bad_head_counts = {"n_heads": 5, "n_kv_heads": 2}

    with pytest.raises(ValueError, match="must be divisible by"):
        ModelConfig(**bad_head_counts)


In [7]:
# Location handed to ModelConfig.from_json_file for this experiment.
# NOTE(review): machine-specific absolute path — consider deriving it from a
# configurable project root so the notebook runs on other machines.
EXPERIMENT_CONFIG_PATH = "/Users/ammar3.shaikh/Desktop/ReLax/experiments"
model_config = ModelConfig.from_json_file(EXPERIMENT_CONFIG_PATH)

In [8]:
from models.llama.config import ModelConfig
from models.llama.model import LLaMA
from utils.kvcache import KVCache

In [9]:
# Instantiate LLaMA from the configuration object loaded into `model_config`.
model = LLaMA(model_config)

In [11]:
import jax
# A JAX PRNG key should feed only one random operation, so split the root key
# into one key for sampling tokens and one for parameter initialisation.
rng = jax.random.PRNGKey(0)
token_rng, init_rng = jax.random.split(rng)

# Dummy (batch=1, seq=10) token ids, drawn from the vocabulary of the config
# the model was built with instead of a hard-coded vocabulary size.
tokens = jax.random.randint(token_rng, (1, 10), 0, model_config.vocab_size)
start_pos = 0

# Take the cache's structural dimensions from the same config as the model so
# the two cannot drift apart. max_seq_len stays at the original 1024 to keep
# the cache allocation unchanged.
# NOTE(review): the TypeError recorded for this cell mixes 2-head and 4-head
# tensors, which looks like a kv-head/query-head broadcast inside the model;
# matching the cache to the config may not resolve it — confirm in the model code.
kvcache = KVCache.new(
    n_layers=model_config.n_layers,
    bsz=tokens.shape[0],
    max_seq_len=1024,
    kv_heads=model_config.n_kv_heads,
    head_dim=model_config.head_dim,
)
params = model.init(init_rng, tokens, start_pos, kvcache)

TypeError: mul got incompatible shapes for broadcasting: (1, 10, 2, 8), (1, 10, 4, 8).

In [13]:
from pathlib import Path
from safetensors.flax import safe_open

# NOTE(review): machine-specific absolute path — consider deriving it from a
# configurable project root so the notebook runs on other machines.
model_path = "/Users/ammar3.shaikh/Desktop/ReLax/artifacts/weights"

# Path.glob yields matches in arbitrary order; sort so the shards (and the
# printed keys) come out in a stable order on every run.
paths = sorted(Path(model_path).glob('*.safetensors'))
if not paths:
    raise ValueError(f"No .safetensors files found in {model_path}")

# Print every tensor name stored in each weight file.
for filepath in paths:
    # The handle is called `weights_file` rather than `f` so that it does not
    # overwrite the notebook-level function `f` in the kernel namespace.
    with safe_open(filepath, framework="flax") as weights_file:
        for key in weights_file.keys():
            print(key)



ValueError: No .safetensors files found in /Users/ammar3.shaikh/Desktop/ReLax/artifacts/weights

In [2]:
import jax

@jax.jit
def f(arr, x):
    """Return the entry of ``arr`` selected by index ``x``.

    The function is compiled with ``jax.jit``; both arguments are passed
    through to a plain ``arr[x]`` indexing expression.
    """
    selected = arr[x]
    return selected

# Five-element integer vector used to exercise `f`.
arr = jax.numpy.array([1, 2, 3, 4, 5])

# Look up position 2 and show the result.
element_at_2 = f(arr, 2)
print(element_at_2)







3


In [3]:
# Look up position 4 of `arr` and show the result.
element_at_4 = f(arr, 4)
print(element_at_4)

5
