In [2]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to library code take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import k_diffusion as K
from k_diffusion.callback import SampleCallback
from k_diffusion.model_loading_utils import load_model, infer_ckpt_path
from k_diffusion.config import SampleCallbackConfig
from pathlib import Path
import torch

# Sampling configuration for checkpoint run "xtb06syu" at step 115000.
# Fields not listed here keep their SampleCallbackConfig defaults, which are
# visible in the printed repr below.
# Alternative model_id (not used here): "htkunb2s"
# NOTE(review): the printed defaults show sigma_max=0.01 and sigma_min=1000.0,
# which looks inverted -- confirm against SampleCallbackConfig.
config = SampleCallbackConfig(
    model_id="xtb06syu",
    model_step=115000,
    model_dir=Path("/shared/amyxlu/kdplaid/"),
    save_to_disk=False,
    log_to_wandb=False,
)

# NOTE(review): the printed config reports device_id=0, while the device used
# throughout this notebook is cuda:2 -- confirm which one is intended.
device = torch.device("cuda:2")

print(config)

SampleCallbackConfig(solver_type=<SampleSolverType.LMS: 'lms'>, seq_len=128, use_ema=True, batch_size=32, n_to_sample=128, n_to_construct=-1, num_recycles=4, sigma_max=0.01, sigma_min=1000.0, rho=7.0, n_steps=15, model_id='xtb06syu', model_step=115000, model_dir=PosixPath('/shared/amyxlu/kdplaid'), device_id=0, save_to_disk=False, log_to_wandb=False, calc_perplexity=True, base_artifact_dir='/shared/amyxlu/kdplaid', sequence_decode_temperature=1.0, calc_fid=True, clip_range=(-1, 1))


In [4]:
# Resolve the checkpoint file from (model_dir, model_id, model_step).
ckpt_path = infer_ckpt_path(
    config.model_dir,
    config.model_id,
    config.model_step,
)

# Write the step encoded in the checkpoint file name (the text before the
# first ".") back into the config, so config.model_step matches the file
# that is actually loaded.
config.model_step = int(Path(ckpt_path).name.partition(".")[0])

# load_model returns the wrapped model, the inner model and the model config.
model, inner_model, model_config = load_model(ckpt_path)

Loading checkpoint from /shared/amyxlu/kdplaid/checkpoints/xtb06syu/00115000.pth


In [20]:
# Build the sampling callback around the loaded model and place it on `device`.
# is_wandb_setup=False matches log_to_wandb=False in the config above.
sampler = SampleCallback(
    model,
    config,
    model_config,
    is_wandb_setup=False,
    device=device,
)

# Sample Latents

In [38]:
# Sample latents for the "unclipped" comparison, then project the raw latent
# to the inner model's input dimension.
#
# NOTE(review): clip_range is read from `config`, whose printed repr shows
# clip_range=(-1, 1) -- the same range used by the clipped run. The recorded
# output of this cell (std ~8.7e9) suggests clipping was disabled when it was
# last executed. Confirm the intended clip_range before trusting a fresh run.
print("Sampling latent...")

# Starred unpacking: another cell in this notebook unpacks a third value from
# this same call, so accept either return arity instead of raising ValueError.
unclipped_sampled_latent, unclipped_sampled_raw, *_ = sampler.sample_latent(
    clip_range=config.clip_range,
    save=False,
    log_wandb_stats=config.log_to_wandb,
    return_raw=True,
)
unclipped_sampled_raw_expanded = inner_model.project_to_input_dim(
    unclipped_sampled_raw.to(device)
)


Sampling latent...


  0%|          | 0/15 [00:00<?, ?it/s]

Sampled latent mean: tensor(-2.0136e+08) std: tensor(8.7010e+09)
Raw sampled latent mean: tensor(-1129329.1250) std: tensor(69186632.)


In [24]:
# Sample latents with values clipped to [-1, 1] during sampling, then project
# the raw latent to the inner model's input dimension.
print("Sampling latent...")

# Starred unpacking: another cell in this notebook unpacks a third value from
# this same call, so accept either return arity instead of raising ValueError.
sampled_latent, sampled_raw, *_ = sampler.sample_latent(
    clip_range=(-1, 1),
    save=False,
    log_wandb_stats=config.log_to_wandb,
    return_raw=True,
)
sampled_raw_expanded = inner_model.project_to_input_dim(sampled_raw.to(device))


Sampling latent...


  0%|          | 0/15 [00:00<?, ?it/s]

Sampled latent mean: tensor(0.4277) std: tensor(89.3935)
Raw sampled latent mean: tensor(-0.0129) std: tensor(0.8801)


# Examine Samples Clipped During Sampling

In [32]:
# Decode the clipped latents into sequences and show the first ten.
# Only the third return value (the decoded strings) is used here.
print("Constructing sequences...")

_, _, strs = sampler.construct_sequence(
    sampled_latent,
    calc_perplexity=config.calc_perplexity,
    save_to_disk=config.save_to_disk,
    log_to_wandb=config.log_to_wandb,
)

for s in strs[:10]:
    print(s)


Constructing sequences...
percentage similarty to argmax idx: 0.898
Perplexity: 8.361
PIPPTPITPPPTPTPIPMTTPTITTIIIIITTIITTIIPTITPTIIIIIIIIIYIIPIIPIPIIIIIITTTIEIIIPIIPITPPPIITPPPPPTPPTTPTPTPPTPPIIITIPTPPITITIIPIPITP
TTTTTTTTTTTTTTTTTTMGTTTTTTTTTTTTTTTTTTTTTTGTTTTTTTTTTTVRTRFTTTTTTTTTTTTTTTTTTTTTTTTTTTRTTTFFTTVTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTT
DFFVTTLVFKPFFVTFVTVTFTTTTTTTTTTTTMTTTTTTTTTTTTTTTTTTTTMTTTTTTTTTTTTTTTTTTTTVVTTTTTRTGRRPTTVQRFTVTTVTFGRFPFRTVTTTVVGTTGCTTGVRGGTG
IIQTTEVQQQITQQQFIQDILQIIFQIIQIIIISTCISIIIEVIQVQIIIIITTTTITIIIIITITITTIYTIITTITTTTTTIIIIITTQITQITQIQITQIIQDYYVDKTQIQIQATTQTITFTTQ
SQKRSGSSGKTSHIKSSHSASSSAATIHMAAHRTKARATSAASSASRKASARGASSTQAKCTTKAKTKAKKVAKWTKDIGIIADSMMIIASAASAAAKCASAAAAAATEASSAATASTSSELASSKSP
PPKPTLPTSKKPPPWLIPPPPTPYYPPSPPPKPPVLLPPKPKQPPPYKNTPPPLKPPPPKKPTTIAAPPACSAKTTAGAIATAITAAAATAAIAAPIIASIAPTIAIYMPKITPIPPKPGPPIIPPTA
QQGIQIHQIYVTYQGIGQQIVIQIKYMAIIQKYAGTESIGGYVLTYYESVVVAVYIVAVLYLIYIQVVIYVIIISVYYYWYYYYYLYYLYLYLVLLILILYQIVIVYIITQCAIIYLAYQIIIETIMY
TTTLTTSSTTE

In [33]:
# Decode the projected raw latents into sequences and show the first ten.
# The wandb prefix keeps these metrics apart from the normalized run when
# wandb logging is enabled.
print("Constructing sequences from unnormalized latent...")

_, _, strs = sampler.construct_sequence(
    sampled_raw_expanded,
    calc_perplexity=config.calc_perplexity,
    save_to_disk=config.save_to_disk,
    log_to_wandb=config.log_to_wandb,
    wandb_prefix="raw_unnormalized_",
)

for s in strs[:10]:
    print(s)


Constructing sequences from unnormalized latent...
percentage similarty to argmax idx: 0.063
Perplexity: 22.722
QCKMVMFMNFHPTNMHFHEIIICPMNNSDHSCGCNISNDLYDHTHSIFERYLEVEIGEIGCIWNIKVTVLKLYEKQFHHQEKQGLHFWGWGLHHTHNHSDIELPKFGSDKLQCVAGVPWIYNIVWSCG
HQFIWFIYQVYSFCGCFFKDSMEIMGVMVMDVGFGYFMCGMCRQWMMMKQDCGNTRDIKCAAEMIELSDINPHPWLRCNKYVWNVTGGGKGDTRSGRYAALNAFRCKPVGLPGYHMDVTVRHSHIEMM
AITMDGRGNNGYAHNKQNALHHGWCCDGERNQYMHKWHKKGQTGSMQFALQHANQQYNFDYTPCMRKDKEKTCDVFDKRFHICTHLLLFSSTVDKVSASAQYDLATPNYEGKGMYGELNRLHASFDPF
FPQTLDVCGCGQYPNSTNSDDHFSNFDEFWIPLFHAKMFPHCLVHWSERHKQMRCPPIWGLLTQLMHEMSFVSVICTSQEMAFWKFLIKNYEAKFNPDAFRKRFGMVTSNHYTIPDLAAIVVLMYLNT
MCDVFFTHVWANGITNRKMKWLPLFYVFGHFWSYCNPMMKFLPATIVLNRFWSFHRTLTTYHTYRKWVAPDHMPTYEISDQLLGPNDRNLRSPGPMMKCVYYFCNGAIVVHLWHPMIVENASPTSWPK
GHPVESWSTCQKTSVIHDGMMTPKRRVPAACLRPFCHSFRNYLTWSKRVPHICINAKKARMLFFYKPGKCHHINIMSLGIMIVDRGNRWDYRGRETPSRWFFGAIDAHQHYWWRKLGWMMNEDICGHQ
YYGRKVDKMRTERSVKGNLNLSQKEPSAPMRDYVIVCRSQFDVCLRSAMTIWCMCWEGCSMLSIYVECPEQYFNFTAGCKVNRPMRTRNCFIHLDAQFICIRCQRWIWPKLLMK

In [35]:
# Structure construction is disabled in this notebook; the call is kept
# commented out for reference. It would take the sampled latents together
# with the decoded sequences (`strs`).
# print("Constructing structures...")
# _, metrics = sampler.construct_structure(
#     sampled_latent,
#     strs,
#     save_to_disk=config.save_to_disk,
#     log_to_wandb=config.log_to_wandb,
# )

In [41]:
# FID / KID for the clipped samples: first on the sampled latent, then on the
# projected raw latent.
# Starred unpacking: another cell in this notebook unpacks a third value from
# calculate_fid, so accept either return arity instead of raising ValueError.
fid, kid, *_ = sampler.calculate_fid(
    sampled_latent,
    log_to_wandb=config.log_to_wandb,
)
print("Normalized", f"fid: {fid}, kid: {kid}")

fid, kid, *_ = sampler.calculate_fid(
    sampled_raw_expanded,
    log_to_wandb=config.log_to_wandb,
    wandb_prefix="raw_unnormalized_",
)
print("Unnormalized", f"fid: {fid}, kid: {kid}")


Normalized fid: 1632370.0, kid: 15257927680.0
Unnormalized fid: 5368259.0, kid: 145586667520.0


# Examine Sequences for Unclipped Samples

In [39]:
print("Constructing sequences...")
_, _, strs = sampler.construct_sequence(
    unclipped_sampled_latent,
    calc_perplexity=config.calc_perplexity,
    save_to_disk=config.save_to_disk,
    log_to_wandb=config.log_to_wandb,
)
for s in strs[:10]: print(s)

Constructing sequences...
percentage similarty to argmax idx: 1.000
Perplexity: 7.510
DDDDDDDDDDDDDDDDDDDDDDCDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDHDDDDDDDDDHDDDDDDDDDHDADHDAADADAAADADACDDHAADDADHDDDADDAADDAADDDAADDD
MCCAMCMCAADDAMACCCCCDAAACMAMCCCDACAMCCCCAMCCAACCAADMCADCACCACDACMAACADCDDDAACCCDDAAADACVAVVAAACAAVAAVAVADDDAVVAVDAAAVAVAVVAAAAAA
DDDAADAADDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDIIDIDDAADADDDDDDD
DDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDD
DDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDDADDDDDDDDDDDDADDDDDDDDADDDDADDDDDDDDDDDDDY
DAADADDDAAAAAAAAAAAAAAADDDAADAAAADADDDAAAADADAAAAAAAAAAAAAAAAAAAAAADAAAAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADDADAADDDDDDDDDDDDDDDD
FDSFFNFFNNDNNFFNFFNFFFFFDNNFFFFDNNFFFNNDNNFNNNNNFFNFDFDNFDNDDNFFFHHDFFFDHHHDHHHFHFHHHFDHHDHHHDNHDFDHHHDDHDHDDDHDHHVAVDDDPDMDMMGP
DDDDDDDDDDD

In [42]:
# FID / KID for the unclipped samples: first on the sampled latent, then on
# the projected raw latent.
# Starred unpacking: another cell in this notebook unpacks a third value from
# calculate_fid, so accept either return arity instead of raising ValueError.
fid, kid, *_ = sampler.calculate_fid(
    unclipped_sampled_latent,
    log_to_wandb=config.log_to_wandb,
)
print("Normalized", f"fid: {fid}, kid: {kid}")

fid, kid, *_ = sampler.calculate_fid(
    unclipped_sampled_raw_expanded,
    log_to_wandb=config.log_to_wandb,
    wandb_prefix="raw_unnormalized_",
)
print("Unnormalized", f"fid: {fid}, kid: {kid}")


Normalized fid: 6.0373738891454446e+22, kid: nan
Unnormalized fid: 1.0907224067014656e+18, kid: nan


# Adjust sigma

In [7]:
# Re-run the clipped sampling experiment with an explicit noise range
# (sigma_min=1e-2, sigma_max=80) and a larger sample count / batch size.
# NOTE: this rebinds `config`, `sampler`, `sampled_latent`, `sampled_raw` and
# `sampled_raw_expanded`, so cells above that are re-run afterwards will see
# these new objects.
config = SampleCallbackConfig(
    model_id="xtb06syu",
    model_step=115000,
    model_dir=Path("/shared/amyxlu/kdplaid/"),
    save_to_disk=False,
    log_to_wandb=False,
    sigma_min=1e-2,
    sigma_max=80,
    n_to_sample=256,
    batch_size=256,
)
sampler = SampleCallback(
    model,
    config,
    model_config,
    is_wandb_setup=False,
    device=device,
)

print("Sampling latent...")
# Starred unpacking: other cells in this notebook unpack only two values from
# sample_latent / calculate_fid, so accept either return arity and discard
# anything beyond the first two values.
sampled_latent, sampled_raw, *_ = sampler.sample_latent(
    clip_range=(-1, 1),
    save=False,
    log_wandb_stats=config.log_to_wandb,
    return_raw=True,
)
sampled_raw_expanded = inner_model.project_to_input_dim(sampled_raw.to(device))

fid, kid, *_ = sampler.calculate_fid(
    sampled_latent,
    log_to_wandb=config.log_to_wandb,
)
print("Normalized", f"fid: {fid}, kid: {kid}")

fid, kid, *_ = sampler.calculate_fid(
    sampled_raw_expanded,
    log_to_wandb=config.log_to_wandb,
    wandb_prefix="raw_unnormalized_",
)
print("Unnormalized", f"fid: {fid}, kid: {kid}")



Sampling latent...


  0%|          | 0/15 [00:00<?, ?it/s]

Sampled latent mean: tensor(1.7744) std: tensor(85.5328)
Raw sampled latent mean: tensor(0.0015) std: tensor(0.2282)
Normalized fid: 238448.96875, kid: 60487729152.0
Unnormalized fid: 5540334.5, kid: 159120441344.0
