In [1]:
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

from llm4structgen.datasets.base_dataset import BaseDataset
from llm4structgen.datasets.prompts import *
from llm4structgen.llms.llama2_utils import *
from llm4structgen.representations.z_matrix import ZMatrix
from llm4structgen.representations.z_matrix import ZMatrix

In [2]:
@dataclass
class ModelConfig:
    """Bundle of run settings handed to the model/tokenizer helpers as ``args``.

    Only ``run_name`` is required; every other field has a default. Field
    order is part of the positional constructor signature and must not be
    changed.
    """

    # Required run identifier (no default).
    run_name: str
    expdir: Path = Path("exp")

    # Model selection and LoRA settings.
    model_name: str = "13b"
    fp8: bool = True
    lora_rank: int = 8
    lora_alpha: int = 32
    lora_dropout: float = 0.05

    # Dataset selection.
    dataset_type: str = "cif"
    data_path: Path = Path("data/mp20-cif/")

    # Optimisation schedule.
    num_epochs: int = 5
    batch_size: int = 2
    gradient_accumulation_steps: int = 1
    lr: float = 5e-4
    lr_scheduler: str = "cosine"
    # The default is fractional (0.03), so this field is a float, not an int.
    warmup_ratio: float = 0.03
    num_warmup_steps: int = 100
    weight_decay: float = 0.0

    # Frequencies for evaluation, checkpointing and logging.
    # NOTE(review): units (steps vs. epochs) are not visible here -- confirm.
    eval_freq: int = 500
    save_freq: int = 500
    log_freq: int = 1

    # Formatting switches.
    format_permute_composition: bool = False
    format_permute_structure: bool = False
    w_attributes: bool = True

    # Both default to None, so the annotations are Optional.
    resume_dir: Optional[Path] = None
    task_probabilities: Optional[dict] = None

    add_perturbed_example: bool = False

In [3]:
# Build the run configuration: only the required run_name is supplied, so
# every other ModelConfig field keeps its declared default.
args = ModelConfig(run_name="test")

In [4]:
# Load the model described by `args`; the captured output shows checkpoint
# shards being loaded. get_model is not defined in this notebook, so it must
# come from one of the star imports.
# NOTE(review): the meaning of the second argument (0) is not visible here --
# presumably a device/rank index; confirm against the helper's signature.
model = get_model(args, 0)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
# Build the tokenizer from the same configuration object used for the model.
# get_tokenizer is not defined in this notebook; it must come from a star import.
tokenizer = get_tokenizer(args)

In [6]:
# Called for its effect on `model` / `tokenizer`; any return value is discarded.
# NOTE(review): presumably adds special tokens and resizes the embedding
# matrix to match -- confirm against the helper's implementation.
smart_tokenizer_and_embedding_resize(model, tokenizer)

In [7]:
# Z-matrix representation with its default constructor arguments; passed as
# the `encoder` of the dataset built below.
zmatrix_encoder = ZMatrix()

In [8]:
# Build the dataset from "val.json" using the tokenizer and Z-matrix encoder
# created above.
# NOTE(review): the parameter is called `data_dir` but receives a file name,
# and the path is relative to the current working directory -- confirm this
# is what BaseDataset expects.
ds = BaseDataset(
    data_dir="val.json", 
    tokenizer=tokenizer, 
    encoder=zmatrix_encoder, 
    # Prompt header constant; not defined in this notebook, so it must come
    # from a star import.
    prompt_header=Z_MATRIX_GENERATION_PROMPT_HEADER,
    attributes=False
)

In [10]:
# Inspect which fields a single dataset example exposes (first example only).
ds[0].keys()

dict_keys(['input_ids', 'labels', 'input_ids_lens', 'labels_lens'])

In [11]:
# Decode the first example's token ids back into text so the full prompt can
# be read as a string.
prompt = tokenizer.decode(ds[0]["input_ids"])

In [13]:
# print() rather than a bare expression so the newlines in the multi-line
# prompt render as line breaks instead of escaped "\n" in a repr.
print(prompt)

<s> Below is a description of a bulk material where each atom is described by its element type and three attributes: 1. distance to the previous atom, 2. angle to the previous two atoms, 3. dihedral angle to the previous three atoms. The first three Fm atoms are dummies that help define the rest of the material.  Generate a description of the lengths and angles of the lattice vectors and the three dummy Fm atoms, followed by the element type and the three attributes for each atom within the lattice:
7.22 7.22 5.64
90 90 120
Fm
Fm
2.2
Fm
3.8 30
Y
4.9 92 121
Y
5.0 59 20
Ho
3.6 44 351
Ho
3.6 60 53
Ho
3.6 60 180
Ho
6.2 29 70
Ho
3.6 73 238
Ho
3.6 60 152</s>
