In [None]:
import torch
import wandb
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm
from pprint import pprint
from copy import deepcopy

from config import device, operators, DecodeType, DecodeInstruction, SamplingType
from ml.tokenizer import Tokenizer
from ml.fontmodel import DecodeInstruction, FontModel
from ml.performance import PerformanceMetrics
from parsing.glyph_viz import Visualizer
from parsing.tablelist_utils import numbers_first, make_non_cumulative

In [None]:
print(f"Executing train-cff.ipynb on {device}...\n-----------------------------")

# Single configuration dictionary read by every later cell.
# Keys are looked up by name, so renaming one requires updating all cells that use it.
args = {
    # Checkpoint / stage switches.
    "load_model": True,
    "train_transformer": True,
    # Numeric range handed to the Tokenizer.
    "min_number": -500,
    "max_number": 500,
    # Architecture, used to build FontModel when no checkpoint is loaded
    # (max_seq_len is also passed to DecodeInstruction).
    "max_seq_len": 5040,
    "num_layers": 12,
    "embedding_dim": 1024,
    "num_heads": 16,
    "ff_dim": 4096,
    "use_wandb": True,
    # Base training settings; epochs / batch_size / batch_accumulate also size the LR schedule.
    "epochs": 15,
    "batch_size": 1,
    "batch_accumulate": 4,
    "lr": 6e-4,
    "dropout_rate": 0.2,
    "weight_decay": 1e-1,
    "gradient_clip": True,
    "gradient_clip_val": 1.0,
    "label_smoothing": 0.001,
    "sample_every": 1,
    "use_scheduler": True,
    "scheduler_warmup_steps": 2000,
    "data_type": torch.bfloat16,
    # VAE / embedding / ViT-encoder pre-training settings.
    # Apart from the two *_batch_size keys (used to build dataloaders), these are
    # not referenced in the visible cells.
    "vae_beta": 1e-1,
    "vae_epochs": 10,
    "vae_lr": 1e-2,
    "vae_weight_decay": 1e-5,
    "freeze_embeddings": False,
    "use_pretrained_embeddings": False,
    "pretrain_embeddings": False,
    "pretrain_epochs": 1,
    "pretrain_batch_size": 128,
    "pretrain_lr": 4e-3,
    "pretrain_use_scheduler": True,
    "pretrain_scheduler_warmup_steps": 3000,
    "use_pretrained_vit_encoder": False,
    "pretrain_vit_encoder": False,
    "pretrain_vit_encoder_epochs": 1,
    "pretrain_vit_encoder_batch_size": 128,
    "pretrain_vit_encoder_batch_accumulate": 1,
    "pretrain_vit_encoder_lr": 1e-3,
    "pretrain_vit_encoder_weight_decay": 1e-3,
    "pretrain_vit_encoder_use_scheduler": True,
    "pretrain_vit_encoder_scheduler_warmup_steps": 1500,
    # Post-training settings consumed by the post-training loop and its dataloader.
    # NOTE(review): "post_train" and "post_train_kl_penalty" are not read in the visible cells.
    "post_train": False,
    "post_train_epochs": 1,
    "post_train_batch_size": 32,
    "post_train_lr": 6e-4,
    "post_train_kl_penalty": 0.05,
    "post_train_use_scheduler": True,
    "post_train_scheduler_warmup_steps": 2000,
}

print("Posttraining hyperparameters:")
pprint(args)

In [None]:
# Special tokens; `pad_token` is also used later to mask padding in the evaluation metrics.
pad_token, sos_token, eos_token = "<PAD>", "<SOS>", "<EOS>"

# Tokenizer over the configured numeric range plus the project's operator set.
tokenizer = Tokenizer(
    min_number=args['min_number'],
    max_number=args['max_number'],
    possible_operators=operators,
    pad_token=pad_token,
    sos_token=sos_token,
    eos_token=eos_token,
)

# Vocabulary size sizes the model's token embedding when a fresh FontModel is built.
vocab_size = tokenizer.num_tokens

# Selects the "_cumulative" dataset file and the make_non_cumulative step when decoding tokens.
cumulative = True

In [None]:
# Decoding settings passed to model.decode() in the post-training loop.
# NOTE(review): temp=0 and p=0 look like placeholders — confirm the sampler never
# divides by `temp` or relies on `p` when SamplingType.MULTINOMIAL is selected.
decode_instr = DecodeInstruction( # NOTE: doesn't matter unless loading from .config.txt fails
    DecodeType.ANCESTRAL,
    SamplingType.MULTINOMIAL,
    max_seq_len=args['max_seq_len'],
    k=5,
    p=0,
    temp=0,
    beam_size=6,
)

In [None]:
models_folder = '~/models'
if args['load_model']:
    # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects, which can
    # execute code. Only load checkpoints produced by this project / a trusted source.
    # NOTE(review): the loaded model is moved to `device` but not cast to args['data_type'];
    # confirm the checkpoint was saved in the intended dtype.
    model_pre = torch.load(f'{models_folder}/transformer-basic-33928allchars_centered_scaled_sorted_filtered_cumulative_padded-14.pkl', map_location=device, weights_only=False).to(device)
else:
    model_pre = FontModel(
        num_enc_layers=args['num_layers'],
        num_dec_layers=args['num_layers'],
        vocab_size=vocab_size,
        embedding_dim=args['embedding_dim'],
        num_heads=args['num_heads'],
        ff_dim=args['ff_dim'],
        dropout_rate=args['dropout_rate'],
        max_seq_len=args['max_seq_len'],
        device=device
    ).to(device, dtype=args['data_type'])

# Trainable policy.
model = torch.compile(model_pre)

# Frozen reference policy used for the probability ratio in post-training.
# The snapshot is taken from the plain module (not from the torch.compile wrapper) so the
# copy does not depend on deep-copying compiler state, then compiled on its own.
original_model = torch.compile(deepcopy(model_pre))
original_model.eval()
# The reference must never receive gradients or be updated.
for ref_param in original_model.parameters():
    ref_param.requires_grad_(False)

In [None]:
# Parameters (tentative):
# FontModel: embedder (DON'T APPLY WEIGHT DECAY)
# TransformerDecoder: transformer_decoder_layers (DON'T APPLY WEIGHT DECAY TO RMSNORM), command_encoder, command_decoder, norm_final (DON'T APPLY WEIGHT DECAY)
# TransformerEncoder: transformer_encoder_layers (DON'T APPLY WEIGHT DECAY TO RMSNORM), embedder (custom),pos_embed, norm_final (DON'T APPLY WEIGHT DECAY)

# We don't want to apply weight decay to layer norms and embeddings.
# Group 1: parameters excluded from weight decay (embeddings, norms, biases, command heads).
no_weight_decay_params = [x for x in model.decoder.embedder.parameters() if x.requires_grad]
no_weight_decay_params += [x for x in model.decoder.inverse_embedder.parameters() if x.requires_grad]
no_weight_decay_params += [x for name, x in model.decoder.transformer_decoder_layers.named_parameters() if x.requires_grad and ('norm' in name or 'bias' in name)]
no_weight_decay_params += [x for x in model.decoder.norm_final.parameters() if x.requires_grad]
no_weight_decay_params += [x for name, x in model.encoder.transformer_encoder_layers.named_parameters() if x.requires_grad and ('norm' in name or 'bias' in name)]
no_weight_decay_params += [x for x in model.encoder.norm_final.parameters() if x.requires_grad]
no_weight_decay_params += [x for name, x in model.encoder.embedder.named_parameters() if x.requires_grad and ('norm' in name or 'bias' in name)]
no_weight_decay_params += [x for x in model.decoder.command_encoder.parameters() if x.requires_grad]
no_weight_decay_params += [x for x in model.decoder.command_decoder.parameters() if x.requires_grad]
no_weight_decay_params += [x for x in model.decoder.command_decoder_2a.parameters() if x.requires_grad]
no_weight_decay_params += [x for x in model.decoder.command_decoder_2b.parameters() if x.requires_grad]
# no_weight_decay_params += [x for x in model.decoder.command_decoder_1.parameters() if x.requires_grad]
# no_weight_decay_params += [x for x in model.decoder.command_decoder_2.parameters() if x.requires_grad]
# no_weight_decay_params += [x for x in model.decoder.W_cn.parameters() if x.requires_grad]
# no_weight_decay_params += [x for x in model.decoder.W_cnb.parameters() if x.requires_grad]

# Group 2: weight matrices that do receive weight decay (everything in the transformer
# stacks and the encoder embedder whose name contains neither 'norm' nor 'bias').
weight_decay_params = [x for name, x in model.decoder.transformer_decoder_layers.named_parameters() if x.requires_grad and 'norm' not in name and 'bias' not in name]
weight_decay_params += [x for name, x in model.encoder.transformer_encoder_layers.named_parameters() if x.requires_grad and 'norm' not in name and 'bias' not in name]
weight_decay_params += [x for name, x in model.encoder.embedder.named_parameters() if x.requires_grad and 'norm' not in name and 'bias' not in name]
weight_decay_params += [x for x in model.encoder.pos_embed.parameters() if x.requires_grad]

# Parameter lists for ViT-encoder pre-training.
# NOTE(review): these two lists are not passed to any optimizer in the visible cells.
vit_encoder_params_nwd = [x for name, x in model.encoder.embedder.named_parameters() if x.requires_grad]# and ('norm' in name or 'bias' in name)]
# vit_encoder_params_nwd += [x for name, x in model.encoder.pretrain_reverse_ae.named_parameters() if x.requires_grad and ('norm' in name or 'bias' in name)]
vit_encoder_params_nwd += [x for name, x in model.encoder.transformer_encoder_layers.named_parameters() if x.requires_grad and ('norm' in name or 'bias' in name)]
vit_encoder_params_nwd += [x for x in model.encoder.norm_final.parameters() if x.requires_grad]
# vit_encoder_params_wd = [x for name, x in model.encoder.embedder.named_parameters() if x.requires_grad and 'norm' not in name and 'bias' not in name]
vit_encoder_params_nwd += [x for name, x in model.encoder.pretrain_reverse_ae.named_parameters() if x.requires_grad]# and 'norm' not in name and 'bias' not in name]
vit_encoder_params_wd = [x for name, x in model.encoder.transformer_encoder_layers.named_parameters() if x.requires_grad and 'norm' not in name and 'bias' not in name]
vit_encoder_params_wd += [x for x in model.encoder.pos_embed.parameters() if x.requires_grad]

# AdamW with per-group weight decay. The no-decay group must use weight_decay=0.0;
# previously both groups were given args['weight_decay'], which silently decayed
# the norm / bias / embedding parameters the lists above were built to protect.
optimizer = torch.optim.AdamW(
    [
       {'params': weight_decay_params, 'weight_decay': args['weight_decay']},
       {'params': no_weight_decay_params, 'weight_decay': 0.0}
    ],
    betas=(0.9, 0.95),
    lr=args['lr']
)

# Dataset geometry: max_len * num_glyphs is used below as the total number of dataset rows,
# and step_every is the stride applied when slicing the datasets.
# (presumably max_len is the number of fonts and num_glyphs the glyphs per font — TODO confirm)
max_len = 33928
num_glyphs = 26
step_every = 1

# NOTE(review): `scheduler` is only created when args['use_scheduler'] is set, but the
# post-training loop calls scheduler.step() under args['post_train_use_scheduler'].
# T_max and the warm-up length are also derived from the base-training keys
# (epochs / batch_size / batch_accumulate / scheduler_warmup_steps), not the post_train_* ones.
# Confirm this is intended.
if args['use_scheduler']:
    # scheduler = TransformerScheduler(
    #     optimizer=optimizer,
    #     dim_embed=args['embedding_dim'],
    #     warmup_steps=args['scheduler_warmup_steps']
    # )
    # Rounded number of batches per epoch over max_len * num_glyphs rows.
    batches_per_epoch = int(max_len * (num_glyphs // step_every) / args['batch_size'] + 0.5)
    # Cosine decay towards eta_min over the expected number of optimizer steps.
    scheduler1 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args['epochs'] * (batches_per_epoch // args['batch_accumulate']), eta_min=1e-5)
    # Linear warm-up from 0.1% of the base LR to the full LR.
    scheduler2 = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.001, end_factor=1.0, total_iters=args['scheduler_warmup_steps'])
    # Both schedulers are stepped together on every scheduler.step() call.
    scheduler = torch.optim.lr_scheduler.ChainedScheduler([scheduler1, scheduler2], optimizer=optimizer)

# Token-sequence dataset and the matching rendered-image dataset (file stems without ".pt").
dataset_name = f"~/basic-33928allchars_centered_scaled_sorted_filtered{'_cumulative' if cumulative else ''}_padded"
im_dataset_name = "~/basic-33928allchars_centered_scaled_sorted_filtered_(128, 128)"

# First 95% of max_len (times num_glyphs rows) is the training split; the remainder is the test split.
train_start = 0
train_end = int(0.95 * max_len) * num_glyphs
test_start = train_end
test_end = max_len * num_glyphs
# max_len = 5
# train_start, train_end = 0, 26*max_len
# test_start, test_end = 0, 26*max_len
train_slice = slice(train_start, train_end, step_every)
test_slice = slice(test_start, test_end, step_every)

# Files are memory-mapped, then sliced into the two splits.
cff_dataset = torch.load(f'./{dataset_name}.pt', mmap=True)[train_slice]
cff_dataset_test = torch.load(f'./{dataset_name}.pt', mmap=True)[test_slice]
im_dataset = torch.load(f'./{im_dataset_name}.pt', mmap=True)[train_slice]
im_dataset_test = torch.load(f'./{im_dataset_name}.pt', mmap=True)[test_slice]


def _shuffled_loader(tensor_dataset, batch_size):
    """Build a shuffling DataLoader over `tensor_dataset` with the given batch size."""
    return DataLoader(tensor_dataset, batch_size=batch_size, shuffle=True)


# One training dataset, served at a different batch size for each training stage.
cff_train_tensor_dataset = TensorDataset(cff_dataset, im_dataset)
cff_train_dataloader = _shuffled_loader(cff_train_tensor_dataset, args['batch_size'])
cff_pretrain_dataloader = _shuffled_loader(cff_train_tensor_dataset, args['pretrain_batch_size'])
cff_pretrain_vit_encoder_dataloader = _shuffled_loader(cff_train_tensor_dataset, args['pretrain_vit_encoder_batch_size'])
cff_posttrain_dataloader = _shuffled_loader(cff_train_tensor_dataset, args['post_train_batch_size'])

# Held-out split (also shuffled, so each evaluation sees a random subset of batches).
cff_test_tensor_dataset = TensorDataset(cff_dataset_test, im_dataset_test)
cff_test_dataloader = _shuffled_loader(cff_test_tensor_dataset, args['batch_size'])

In [None]:
print("\nPost-training model...\n")

# Alias for the functional KL divergence; value_fn uses it as the image-similarity score.
# NOTE(review): per the PyTorch docs, kl_div expects its first argument in log-space and
# its second as probabilities — confirm every caller complies.
kl_loss = torch.nn.functional.kl_div

@torch.no_grad()
def value_fn(image_gt, output_tokens):
    '''
    Score a (partial) token sequence by rendering it and comparing it to the target image.

    image_gt: the ground truth image, scaled to [-1, 1] by the caller
    output_tokens: the model's predicted output (token ids); flattened before decoding

    Returns the negated KL-divergence score between the rendered glyph and the
    ground-truth image (higher is better).
    '''
    sequence = output_tokens.cpu().detach().numpy().flatten()
    torch.cuda.empty_cache()
    # Token ids -> symbolic tokens, dropping padding.
    toks = [tokenizer.reverse_map(tk.item(), use_int=True) for tk in sequence]
    toks = [tok for tok in toks if tok != '<PAD2>' and tok != '<PAD>']
    if cumulative:
        toks = numbers_first(make_non_cumulative(toks, tokenizer), tokenizer, return_string=False)
    else:
        toks = numbers_first(toks, tokenizer, return_string=False)
    # Render the decoded tokens. The raw id tensor used to be passed here, which left
    # all of the decoding above unused.
    viz = Visualizer(toks)
    im_pixel_size = (128, 128)
    crop_factor = 1
    dpi = 1
    im_size_inches = ((im_pixel_size[0] * crop_factor) / dpi, (im_pixel_size[1] * crop_factor) / dpi)
    output_image = viz.draw(
        display=False,
        filename=None,
        return_image=True,
        center=False,
        im_size_inches=im_size_inches,
        bounds=(-300, 300),
        dpi=dpi
    )[None,:,:,0] / 255.0
    # The renderer's output is converted explicitly so kl_div receives tensors on one device.
    # (assumes viz.draw returns an (H, W, C) array-like with values in [0, 255] — TODO confirm)
    output_image = torch.as_tensor(output_image, dtype=torch.float32, device=image_gt.device)
    # kl_div expects its first argument in log-space; clamp avoids log(0) on empty pixels.
    # Computed in float32 because the images may arrive in a reduced-precision dtype.
    # NOTE(review): the images are not normalised to sum to 1, so this is a pixel-wise
    # KL-style score rather than a divergence between proper distributions.
    gt_log = torch.log(((image_gt.float() + 1) / 2).clamp_min(1e-8))
    value = -kl_loss(gt_log, output_image)
    return value

@torch.no_grad()
def advantage_fn(image_gt, output_tokens):
    '''
    Advantage of the most recent step: value of the full token sequence ("next state")
    minus the value of the same sequence with its last 7 tokens removed ("current state").

    image_gt: the ground truth image
    output_tokens: the model's predicted output
    '''
    value_next = value_fn(image_gt, output_tokens)
    value_current = value_fn(image_gt, output_tokens[:, :-7])
    return value_next - value_current


In [None]:
# Post-training loop: PPO-style clipped policy-gradient update followed by a short evaluation.
# NOTE(review): `src` is not read anywhere in the visible cells; kept in case another cell uses it.
src = torch.zeros((args['post_train_batch_size'], 0)).to(device)
for epoch in range(args['post_train_epochs']):
    model.train()
    optimizer.zero_grad()
    total_loss = 0
    last_loss = 0
    # Take the batch count from the loader itself. The previous estimate was based on the
    # full dataset size, so `idx == train_batches-1` never matched the real last batch.
    train_batches = len(cff_posttrain_dataloader)
    for idx, (X, im) in enumerate(tqdm(cff_posttrain_dataloader, total=train_batches)):
        if idx >= train_batches:
            break
        inputs = X.to(device, dtype=torch.int32)
        im = im.to(dtype=args['data_type'], device=device).unsqueeze(1) / 127.5 - 1.0

        # output tokens from the current model
        # Kept as a tensor: the code below slices it in 2-D, feeds it back to the model and
        # uses it as a gather index, none of which works on a flattened numpy array.
        # NOTE(review): confirm that model.decode(...)[0] is the (batch, seq_len) token tensor.
        out_tokens = model.decode(im, None, decode_instr)[0].detach()
        if out_tokens.dim() == 1:
            out_tokens = out_tokens.unsqueeze(0)

        # token distributions
        in_tokens = out_tokens[:,:-7] # note: SOS token is prepended in forward()
        dist_new = model(im, in_tokens)
        with torch.no_grad():
            dist_original = original_model(im, in_tokens)

        # One advantage per 7-token step, repeated so every token in a step shares it.
        # NOTE(review): value_fn flattens the whole batch into one sequence, so this is only
        # meaningful for a batch of one; for larger batches the single row is broadcast.
        num_steps = out_tokens.shape[1] // 7
        step_adv = torch.stack([
            torch.as_tensor(advantage_fn(im, out_tokens[:,:i*7]), dtype=torch.float32, device=device).reshape(())
            for i in range(1, num_steps + 1)
        ])
        adv = step_adv.repeat_interleave(7).unsqueeze(0) # (1, num_steps * 7)

        # Probability assigned to each sampled token: gather over the vocabulary (last) axis
        # with one index per sequence position.
        # NOTE(review): assumes model(...) returns probabilities of shape (batch, seq, vocab);
        # if it returns logits, apply softmax / log_softmax before forming the ratio.
        token_index = out_tokens.long().unsqueeze(-1)
        prob_new = torch.gather(dist_new, dim=-1, index=token_index).squeeze(-1)
        prob_original = torch.gather(dist_original, dim=-1, index=token_index).squeeze(-1)
        rel_prob = prob_new / prob_original # (batch_size, seq_len)
        eps = 0.2
        # The clipped surrogate is an objective to maximise, so the loss is its negated mean
        # (backward() and .item() also require a scalar).
        surrogate = torch.minimum(rel_prob * adv, torch.clip(rel_prob, 1-eps, 1+eps) * adv)
        loss = -surrogate.mean()

        total_loss += loss.item() * X.shape[0]
        loss.backward()
        torch.cuda.empty_cache()

        # `% 1` is always true: the optimizer steps on every batch (no gradient accumulation).
        if (idx+1) % 1 == 0 or idx == train_batches-1:
            if args['gradient_clip']:
                torch.nn.utils.clip_grad_value_(model.parameters(), args['gradient_clip_val'])
            optimizer.step()
            optimizer.zero_grad()
            if args['post_train_use_scheduler']:
                scheduler.step()
            diff = total_loss - last_loss
            last_loss = total_loss

        if args['use_wandb']:
            if (idx+1) % 100 == 0 or (idx == train_batches-1 and (epoch+1) % args['sample_every'] == 0):
                # NOTE(review): `decode(epoch, idx)` is not defined in the visible cells.
                goal_image, img_arr = decode(epoch, idx)
                wandb.log({
                    "posttrain_goal_image": goal_image,
                    "posttrain_images": img_arr,
                    "posttrain_loss_step": diff / (args['batch_accumulate'] * args['post_train_batch_size']),
                    "posttrain_lr_step": args['post_train_lr'] if not args['post_train_use_scheduler'] else scheduler.get_last_lr()[0],
                })
            elif (idx+1) % 1 == 0:
                wandb.log({
                    "posttrain_loss_step": diff / (args['batch_accumulate'] * args['post_train_batch_size']),
                    "posttrain_lr_step": args['post_train_lr'] if not args['post_train_use_scheduler'] else scheduler.get_last_lr()[0],
                })
    # Normalise by the batch size the post-training loader actually uses.
    train_loss = total_loss / (min(train_batches, idx+1)*args['post_train_batch_size'])
    
    model.eval()
    total_loss = 0
    test_batches = 25
    true_positives = 0
    false_positives = 0
    true_negatives = 0
    false_negatives = 0
    with torch.no_grad():
        for idx, (X, im) in enumerate(tqdm(cff_test_dataloader, total=test_batches)):
            if idx >= test_batches:
                break
            inputs = X.to(device, dtype=torch.int32)
            im = im.to(dtype=args['data_type'], device=device).unsqueeze(1) / 127.5 - 1.0
            out = model(im, inputs[:,:-7]) # Use only output tokens before this truth term

            # loss = loss_fn(out.permute(0, 2, 1), inputs.long()) / X.shape[0]
            # NOTE(review): `numeric_mse_loss` is not defined in the visible cells.
            loss = numeric_mse_loss(out, inputs) / X.shape[0]
            
            total_loss += loss.item() * X.shape[0]
            torch.cuda.empty_cache()

            # Token-level confusion counts, with padding positions treated as the negative class.
            guesses = out.permute(0, 2, 1).argmax(dim=1)
            truths = inputs
            true_positives += ((guesses == truths) * (truths != tokenizer[pad_token])).sum()
            false_positives += ((guesses != truths) * (truths == tokenizer[pad_token])).sum()
            true_negatives += ((guesses == truths) * (truths == tokenizer[pad_token])).sum()
            false_negatives += ((guesses != truths) * (truths != tokenizer[pad_token])).sum()
        
        test_loss = total_loss / (min(test_batches, idx+1)*args['batch_size'])
        acc, pre, rec, f1 = PerformanceMetrics.all_metrics(
            tp=true_positives,
            fp=false_positives,
            tn=true_negatives,
            fn=false_negatives
        )

        # Report this epoch's losses directly; the *_list variables previously referenced
        # here are never created or appended to in the visible cells.
        print(f"Epoch {epoch+1}/{args['post_train_epochs']} completed. Train Loss = {train_loss};  Test Loss: {test_loss}")

        if args['use_wandb']:
            wandb.log({
                "posttrain_loss": train_loss,
                "posttrain_test_loss": test_loss,
                # "test_accuracy": acc,
                # "test_precision": pre,
                # "test_recall": rec,
                # "test_f1": f1,
                "lr": args['lr'] if not args['use_scheduler'] else scheduler.get_last_lr()[0],
            })

    # if (epoch+1) % 100 == 0 or epoch+1 == args['epochs']:
    # if max_len > 100:
    #     torch.save(model, f'models/transformer-{dataset_name}-{epoch+1}.pkl')
