In [1]:
import os

import torch
import wandb
from hydra import compose, initialize
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader

from src.utils.load import load_cdvi_for_bml
from src.train.train_bml import BetterBMLTrainer
from src.train.train_bml_alternating import AlternatingBMLTrainer
from src.utils.visualize import visualize_cdvi_for_bml_test

In [2]:
# Select the compute device: DirectML when the optional `torch_directml` backend
# is installed, otherwise CUDA when available, otherwise CPU.
try:
    import torch_directml
    device = torch_directml.device()
except ImportError:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

device

In [3]:
# Experiment directory: the Hydra config `cfg` is composed from here, and the
# checkpoints (cdvi.pth / optim.pth / cdvi_finetuned.pth) are read from / written to it.
# NOTE(review): `dir` shadows the builtin `dir()`; later cells reference this
# name, so rename it everywhere at once if you change it.
MODELS_ROOT = "../models"
EXPERIMENT = "test"
dir = f"{MODELS_ROOT}/{EXPERIMENT}"

In [None]:
# Compose the Hydra config stored in `dir` and build the CDVI model, its
# optimizer and the train/validation loaders from it.
with initialize(version_base=None, config_path=dir):
    cfg = compose(config_name="cfg")
    cdvi, optimizer, train_loader, val_loader = load_cdvi_for_bml(
        cfg, alternating_ratio=0.2, device=device
    )

# Warm-start from checkpoints in `dir` when they exist. Tensors are mapped to
# CPU on load; load_state_dict then copies them into the objects built above.
cdvi_path = f"{dir}/cdvi.pth"
optim_path = f"{dir}/optim.pth"

if os.path.exists(cdvi_path):
    cdvi.load_state_dict(
        torch.load(cdvi_path, map_location=torch.device("cpu"), weights_only=True)
    )
    print(f"loaded cdvi from {cdvi_path}")

if os.path.exists(optim_path):
    optimizer.load_state_dict(
        torch.load(optim_path, map_location=torch.device("cpu"), weights_only=True)
    )
    print(f"loaded optim from {optim_path}")

Generating tasks: 100%|██████████| 8192/8192 [00:00<00:00, 285683.60it/s]


In [5]:
# Weights & Biases experiment tracking. The same flag is passed to the trainer
# below, so turning it off presumably also disables the trainer's own logging.
WANDB_PROJECT = "test-12345-alternating"
wandb_logging = True

if wandb_logging:
    wandb.init(project=WANDB_PROJECT)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mmax-burzer[0m ([33mmax-burzer-karlsruhe-institute-of-technology[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [6]:
# No learning-rate scheduler is used for this run; AlternatingBMLTrainer is not
# given one. To experiment with plateau scheduling (e.g. with the non-alternating
# BetterBMLTrainer, which takes `scheduler=`), use something like:
#   scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.3, patience=500)
scheduler = None

# Alternating trainer: `train_loader` is indexed as a pair. Element 0 feeds the
# decoder phase and element 1 feeds the CDVI phase.
# NOTE(review): presumably the split comes from `alternating_ratio` in
# load_cdvi_for_bml; confirm the order there.
trainer = AlternatingBMLTrainer(
    device=device,
    cdvi=cdvi,
    train_decoder_loader=train_loader[0],
    train_cdvi_loader=train_loader[1],
    val_loader=val_loader,
    optimizer=optimizer,
    wandb_logging=wandb_logging,
)

In [None]:
# Run training with per-epoch validation. `max_clip_norm` and `alpha` are passed
# as None; cfg.training.max_clip_norm / cfg.training.alpha are the config-driven
# alternatives if you want them back.
num_epochs = 100
losses = trainer.train(
    validate=True,
    alpha=None,
    max_clip_norm=None,
    num_epochs=num_epochs,
)

In [None]:
# Visualize CDVI samples on validation tasks drawn one at a time (batch_size=1).
max_context_size = 6
num_samples = 256
vis_seed = 0  # fixes which validation tasks are drawn, so re-runs show the same tasks

# A dedicated seeded generator makes task selection reproducible and keeps the
# shuffle from consuming the global torch RNG state.
# NOTE(review): any sampling inside visualize_cdvi_for_bml_test is not seeded here.
gen_dataloader = DataLoader(
    val_loader.dataset,
    batch_size=1,
    shuffle=True,
    generator=torch.Generator().manual_seed(vis_seed),
)

# Plot window: presumably one (low, high) pair per plotted dimension — confirm.
ranges = [(-6, 6), (-6, 6)]

targets, samples = visualize_cdvi_for_bml_test(
    device=device,
    cdvi=cdvi,
    dataloader=gen_dataloader,
    config=cfg,
    num_samples=num_samples,
    max_context_size=max_context_size,
    ranges=ranges,
)

In [None]:
# Save the fine-tuned weights under a new name so the original cdvi.pth in `dir`
# is left untouched.
# NOTE(review): optimizer state is not saved here; add optim_finetuned.pth if
# training is to be resumed from this point.
torch.save(cdvi.state_dict(), f"{dir}/cdvi_finetuned.pth")

# Close the W&B run opened earlier so its data is flushed and the run is marked
# finished; wandb.finish() does nothing if no run is active.
if wandb_logging:
    wandb.finish()