In [1]:
import torch

from v2.usta_model import UstaModel
from v2.usta_tokenizer import UstaTokenizer

device = "cpu"

if torch.cuda.is_available():
  device = "cuda"
elif torch.backends.mps.is_available():
  device = "mps"
  

print(f"Using device: {device}")

u_tokenizer = UstaTokenizer("v1/tokenizer.json")

prompts = [
  "the capital of the united",
  "madrid is in",
  "the capital of france is",
  "the capital of germany is"
]

tokens = u_tokenizer.encode(prompts[0])
tokens = tokens.to(device)
print(tokens)
batch_tokens = u_tokenizer.encode_batch(prompts, 32)
batch_tokens = batch_tokens.to(device)
batch_tokens.shape

Using device: mps
tensor([ 0, 61,  1, 61,  2, 61,  0, 61,  3], device='mps:0')


torch.Size([4, 32])

In [2]:
context_length = 32
torch.manual_seed(1)
u_model = UstaModel(
  vocab_size=len(u_tokenizer.vocab),
  embedding_dim=12,
  num_heads=4,
  context_length=context_length,
  num_layers=8,
  device=device
)

out = u_model(batch_tokens)
out.shape

torch.Size([4, 32, 64])

In [3]:
tokens.unsqueeze(0)

tensor([[ 0, 61,  1, 61,  2, 61,  0, 61,  3]], device='mps:0')

In [4]:
u_model.generate(tokens, 2)

[0, 61, 1, 61, 2, 61, 0, 61, 3, 17, 49]

In [5]:
tokens.shape

torch.Size([9])

In [6]:
# save model
torch.save(u_model.state_dict(), "u_model.pth")

# load model
u_model.load_state_dict(torch.load("u_model.pth"))

<All keys matched successfully>

In [None]:
# %pip install -U git+https://github.com/huggingface/transformers

Collecting git+https://github.com/huggingface/transformers
  Cloning https://github.com/huggingface/transformers to /private/var/folders/z7/wrd0w0hn7pvb9g97kmdn17640000gn/T/pip-req-build-3az1z0u8
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers /private/var/folders/z7/wrd0w0hn7pvb9g97kmdn17640000gn/T/pip-req-build-3az1z0u8
^C
[31mERROR: Operation cancelled by user[0m[31m
[0mNote: you may need to restart the kernel to use updated packages.


In [2]:
# Load model directly
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("facebook/vjepa2-vitl-fpc64-256")
model = AutoModel.from_pretrained("facebook/vjepa2-vitl-fpc64-256")

In [3]:
model

VJEPA2Model(
  (encoder): VJEPA2Encoder(
    (embeddings): VJEPA2Embeddings(
      (patch_embeddings): VJEPA2PatchEmbeddings3D(
        (proj): Conv3d(3, 1024, kernel_size=(2, 16, 16), stride=(2, 16, 16))
      )
    )
    (layer): ModuleList(
      (0-23): 24 x VJEPA2Layer(
        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
        (attention): VJEPA2RopeAttention(
          (query): Linear(in_features=1024, out_features=1024, bias=True)
          (key): Linear(in_features=1024, out_features=1024, bias=True)
          (value): Linear(in_features=1024, out_features=1024, bias=True)
          (proj): Linear(in_features=1024, out_features=1024, bias=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
        (mlp): VJEPA2MLP(
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (activation): GELUActivation()
          (fc

In [4]:
tokenizer = AutoTokenizer.from_pretrained("facebook/vjepa2-vitl-fpc64-256")

KeyError: <class 'transformers.models.vjepa2.configuration_vjepa2.VJEPA2Config'>