In [1]:
from hydra import compose, core, initialize
from hydra.core.global_hydra import GlobalHydra
from pathlib import Path
import os


def compose_config_from_path(config_path, config_name="config"):
    """
    Compose a Hydra configuration from a directory of config files.

    The directory is passed to Hydra as an absolute path through
    ``initialize_config_dir``, so the result does not depend on the current
    working directory and the process working directory is never changed.

    Args:
        config_path (str): Path to the directory containing config files.
            A relative path is resolved against the current working directory.
        config_name (str, optional): Name of the main config file (without .yaml extension).
                                   Defaults to "config".

    Returns:
        The configuration object returned by ``hydra.compose``.

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        ValueError: If ``config_path`` exists but is not a directory.
        Exception: Any error raised by Hydra while initializing or composing
            is re-raised unchanged, after the global Hydra state is cleared.
    """
    # Validate the input before touching any global state, so that a bad path
    # does not wipe an already initialized global Hydra instance.
    config_dir = Path(config_path)
    if not config_dir.exists():
        raise FileNotFoundError(f"Config directory not found: {config_dir}")

    if not config_dir.is_dir():
        raise ValueError(f"Config path must be a directory: {config_dir}")

    # Function-scope import: this is the only place that needs the
    # absolute-directory initializer.
    from hydra import initialize_config_dir

    # Reset Hydra to avoid conflicts with a previous initialization
    GlobalHydra.instance().clear()

    try:
        # Deliberately not a ``with`` block: the global Hydra instance stays
        # initialized after this function returns.
        initialize_config_dir(config_dir=str(config_dir.resolve()), version_base=None)

        # Compose the configuration
        return compose(config_name=config_name)
    except Exception:
        # Leave no half-initialized Hydra instance behind on failure
        GlobalHydra.instance().clear()
        raise

In [2]:
# Configure the environment BEFORE importing project modules, so that anything
# read or initialized at import time already sees these values.
# NOTE(review): whether `encoder_trainer` reads PROJECT_ROOT or touches CUDA at
# import time is not visible from this notebook; setting them first is safe either way.
# TODO: the absolute path below is machine-specific; consider reading it from the
# environment instead of hard-coding it.
os.environ['PROJECT_ROOT'] = '/mnt/virtual_ai0001071-01239_SR006-nfs2/afedorov/projects/LatentDiffusion'
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

from encoder_trainer import EncoderTrainer

# Compose the project config from the local ./conf directory.
cfg = compose_config_from_path('./conf')

# Notebook-only overrides applied on top of the composed config.
cfg.autoencoder.training.batch_size=128
cfg.encoder.latent.num_latents=1
cfg.decoder.latent.num_latents=1

# Build the trainer from the overridden config.
trainer = EncoderTrainer(cfg)






Tokenizer: bert-base-cased - BertTokenizer(name_or_path='bert-base-cased', vocab_size=28996, model_max_length=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True, added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
)






Loading text encoder: bert-base-cased





KeyboardInterrupt: 

In [6]:
import torch
from typing import Dict

from dataloader import get_dataloaders
from transformers import AutoTokenizer

# Switch both halves of the autoencoder to evaluation mode.
trainer.encoder.eval()
trainer.decoder.eval()

# Running accumulators filled in by the validation loop.
total_loss = torch.zeros(1)
valid_count = torch.zeros(1)
valid_dict: Dict[str, torch.Tensor] = {}

# Tokenizer matching the text encoder named in the config.
tokenizer = AutoTokenizer.from_pretrained(cfg.autoencoder.model.text_encoder)

# Only the validation loader is kept; the first returned loader is discarded.
_, valid_loader = get_dataloaders(
    cfg,
    tokenizer,
    skip_train=False,
    skip_valid=False,
    valid_seed=None,
)

In [4]:
# Print the current state of `cfg` rendered as YAML, for inspection.
from omegaconf import OmegaConf
print(OmegaConf.to_yaml(cfg))

autoencoder:
  model:
    checkpoints_prefix: autoencoder
    text_encoder: bert-base-cased
    text_encoder_freeze_params: true
    num_workers: 10
    load_checkpoint: autoencoder/200000.pth
  loss:
    level_weights:
    - 1.0
    - 1.0
    - 1.0
    - 1.0
    - 1.0
    - 1.0
  training:
    training_iters: 200000
    batch_size: 128
    batch_size_per_gpu: 512
  params:
    text_encoder: 0
    encoder: 99275520
    decoder: 126266162
    total: 225541682
  all_params: dict()
  logging:
    log_freq: 10
    eval_freq: 20000
    save_freq: 20000
  optimizer:
    name: stableadam
    learning_rate: 0.0002
    warmup_lr: 1.0e-08
    min_lr: 0.0001
    weight_decay: 1.0e-05
    eps: 1.0e-06
    betas:
    - 0.9
    - 0.98
    linear_warmup: 10
    grad_clip_norm: 10.0
encoder:
  attention:
    head_size: 64
    num_heads: 12
    probs_dropout: 0.0
    qk_norm: true
    implementation: flash_attention_2
  embedding:
    dim: 768
    max_position_embeddings: 128
    initializer_range: 0.0

In [7]:
from tqdm import tqdm

# Accumulate sample-weighted sums of every reported metric over the validation
# loader; gradients are not needed for evaluation.
with torch.no_grad():
    for batch in tqdm(valid_loader):
        batch_size = batch["input_ids"].shape[0]
        batch_loss, loss_dict = trainer.calc_loss(batch)

        for k, metric in loss_dict.items():
            weighted = metric * batch_size
            if k not in valid_dict:
                # First occurrence of this metric: start its running sum.
                valid_dict[k] = torch.Tensor([weighted])
            else:
                valid_dict[k] += weighted

        total_loss += batch_loss.item() * batch_size
        valid_count += batch_size
        

100%|██████████| 391/391 [00:35<00:00, 11.02it/s]


In [8]:
# `loss_dict` is the variable left over from the last iteration of the
# validation loop, so this shows the metrics of the final batch only,
# not an aggregate over the whole loader.
loss_dict

{'ce_loss': 0.6875,
 'mse_loss': 0.048601116985082626,
 'accuracy': 0.828320324420929,
 'variation_loss': 0.20029935240745544}

In [9]:
# Sample-weighted mean accuracy: each batch's accuracy was accumulated
# multiplied by its batch size, and `valid_count` is the total number of samples.
valid_dict['accuracy'] / valid_count

tensor([0.8228])

In [1]:
from transformers import AutoModel

# Load pretrained bert-base-cased (without the pooling layer) and keep only its
# word-embedding submodule. Despite the name, `embedding_matrix` is used as a
# callable on token ids further down, not as a raw weight tensor.
embedding_matrix = AutoModel.from_pretrained(
    'bert-base-cased',
    add_pooling_layer=False,
).embeddings.word_embeddings

In [10]:
import torch

# Shape sanity check: embed a random batch of token ids and project the
# embeddings from 768 to 1024 features.
l = torch.nn.Linear(in_features=768, out_features=1024)
token_ids = torch.randint(low=0, high=100, size=(5, 10))
t = embedding_matrix(token_ids)
(l(t).shape, t.shape)

(torch.Size([5, 10, 1024]), torch.Size([5, 10, 768]))