In [None]:
import torch
from torch import nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import torch.nn.functional as F
from torch.utils.data import DataLoader
import gc
import numpy as np
import itertools
import math

from sonar.inference_pipelines.text import EmbeddingToTextModelPipeline
from sentence_transformers import SentenceTransformer

import wandb
import traceback

from transformers.utils import logging
logging.set_verbosity_error()

# Use CUDA when it is available, otherwise the CPU. The functions in this
# notebook move their tensors and models to `device`.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
d_emb = 1024  # not referenced by the other cells visible in this notebook

# Padding embedding, loaded once and kept on `device`; `criterion` compares
# targets against it to find the padded positions.
pad = torch.from_numpy(np.load('/LCM/data/pad.npy')).to(device)

# Loss / similarity modules shared by the training and scoring code.
mse = nn.MSELoss()
cossim = nn.CosineEmbeddingLoss()
cos = nn.CosineSimilarity(dim=-1)

def criterion(output, target):
    """Compute the cosine-embedding and MSE losses over non-padded positions.

    A position counts as padding when every component of its target vector
    equals the module-level `pad` embedding; those positions are removed from
    both tensors before the losses are evaluated.

    Args:
        output: predicted embeddings, indexed like `target`.
        target: reference embeddings; the last axis is the embedding axis.

    Returns:
        A two-element list `[cosine_embedding_loss, mse_loss]`.
    """
    is_padding = (target == pad).all(dim=-1)  # [batch, seq]
    keep = ~is_padding

    kept_target = target[keep]
    kept_output = output[keep]

    # A label of 1 asks CosineEmbeddingLoss to pull each pair together.
    positive_label = torch.full((1,), 1).to(device)
    cosine_term = cossim(kept_output, kept_target, positive_label)
    mse_term = mse(kept_output, kept_target)
    return [cosine_term, mse_term]

In [None]:
def validation(epoch, cpt, model, valloader, val_list):
    """Evaluate `model` on `valloader` and record the mean losses.

    The model is switched to eval mode for the pass and back to train mode
    before returning. The mean cosine and MSE losses over all batches are
    appended to `val_list`, printed, and logged to wandb.

    Args:
        epoch: zero-based epoch index (printed as `epoch + 1`).
        cpt: running index of this validation pass.
        model: object exposing `predict_next_sentence`, `eval` and `train`.
        valloader: iterable of batches whose axis 1 is the sequence axis.
        val_list: list that receives `[mean_cossim, mean_mse]`.

    Returns:
        `(val_list, cpt + 1)`.
    """
    running_cossim = 0
    running_mse = 0
    model.eval()
    with torch.no_grad():
        for batch in valloader:
            batch = batch.to(device)
            # Predict each position from its predecessors: inputs drop the
            # last element, targets drop the first.
            predictions = model.predict_next_sentence(batch[:, :-1])
            batch_cossim, batch_mse = criterion(predictions, batch[:, 1:])
            running_cossim += batch_cossim.item()
            running_mse += batch_mse.item()

        mean_cossim = running_cossim / len(valloader)
        mean_mse = running_mse / len(valloader)
        val_list.append([mean_cossim, mean_mse])
        print('Epoch', epoch+1, 'Part', cpt, "Cossim", mean_cossim, "MSE", mean_mse)
        wandb.log({'Cossim': mean_cossim, 'MSE': mean_mse})
    model.train()
    return val_list, cpt + 1


def autoregr_infer(model, prompts, num_generated=19):
    """Autoregressively extend `prompts` and return only the generated part.

    At every step the model is run on the whole context and the prediction
    for the LAST position is appended along axis 1. The previous version
    appended the predictions for every prompt position on the first step and
    stripped exactly one prompt element at the end, which is only correct for
    one-sentence prompts; for such prompts the result is unchanged.

    Args:
        model: object exposing `eval()` and `predict_next_sentence(x)`, whose
            output keeps the sequence length of `x` on axis 1.
        prompts: tensor whose axis 0 is the batch and axis 1 the sequence.
        num_generated: number of positions to generate (defaults to the
            previously hard-coded 19).

    Returns:
        Tensor holding the `num_generated` generated positions for every
        prompt (the prompt itself is not included).

    Note:
        The model is left in eval mode, as before.
    """
    model.eval()
    with torch.no_grad():
        context = prompts.to(device)
        prompt_len = context.size(1)
        for _ in range(num_generated):
            # Keep only the prediction that follows the current last element.
            next_sentence = model.predict_next_sentence(context)[:, -1:]
            context = torch.cat((context, next_sentence), dim=1)
    # Drop the whole prompt, whatever its length; contiguous() mirrors the
    # fresh tensor the previous torch.cat-based return produced.
    return context[:, prompt_len:].contiguous()


def calculate_score(output, targets, seq=19):
    """Score predictions against references with cosine similarity.

    For every reference sequence, the matching prediction is truncated to the
    reference length and compared position by position along the last axis.

    Compared with the previous version, each prediction is moved to the
    device of its reference (instead of a global device the reference was
    never moved to), empty `targets` raises a clear error, and the number of
    reported positions is a parameter.

    Args:
        output: indexable collection of predicted sequences; `output[i]` is
            sliced along its first axis.
        targets: sequence of reference tensors, one per sample.
        seq: number of positions reported in the per-position scores
            (defaults to the previously hard-coded 19).

    Returns:
        `(final_score, paragraphed_score)`: the mean over samples of each
        sample's mean similarity, and a list of `seq` per-position mean
        similarities rounded to 2 decimals (NaN where no sample reaches that
        position).

    Raises:
        ValueError: if `targets` is empty.
    """
    if len(targets) == 0:
        raise ValueError("calculate_score needs at least one target sequence")

    acc_device = targets[0].device
    score_one_sum = 0
    score_sum_pad = torch.zeros(seq, device=acc_device)
    pad_nbr_sum = torch.zeros(seq, device=acc_device)

    for batch, target in enumerate(targets):
        out = output[batch][:len(target)].to(target.device)

        score = F.cosine_similarity(out, target, dim=-1).to(acc_device)

        score_one_sum = score_one_sum + score.mean()
        # Right-pad to `seq` so samples of different lengths can be summed;
        # the second accumulator counts how many samples cover each position.
        fill = seq - len(score)
        score_sum_pad += F.pad(score, (0, fill))
        pad_nbr_sum += F.pad(torch.ones(len(score), device=acc_device), (0, fill))

    paragraphed_score = score_sum_pad / pad_nbr_sum
    paragraphed_score = [round(elem.item(), 2) for elem in paragraphed_score]
    final_score = score_one_sum.item() / len(targets)
    return final_score, paragraphed_score


def test(model, test_data, config):
    """Run the final evaluation of `model` and log the scores to wandb.

    Steps:
      1. Generate continuations with `autoregr_infer` and score them against
         the reference SONAR embeddings.
      2. Decode the first 10 generated sequences to English text with the
         SONAR decoder, re-embed the text with Jasper and score it against
         the reference Jasper embeddings.
      3. Save `[state_dict, config]` when either score exceeds 0.4, then log
         both scores and their mean as 'Final score'.

    The helper models are released before the next one is loaded; references
    are dropped and garbage-collected BEFORE `torch.cuda.empty_cache()`, since
    emptying the cache first cannot release memory that is still referenced.

    Args:
        model: trained model passed to `autoregr_infer`.
        test_data: `(sonarprompt, sonaroutput, jasperoutput)`.
        config: run configuration, stored next to the weights.
    """
    sonarprompt, sonaroutput, jasperoutput = test_data

    output_autoregr = autoregr_infer(model, sonarprompt)
    final_sonar, paragraphed_sonar = calculate_score(output_autoregr, sonaroutput)

    print('Sonar score:', final_sonar, paragraphed_sonar)
    wandb.log({'Final sonar': final_sonar, 'Paragraphed sonar': paragraphed_sonar})

    vec2text_model = EmbeddingToTextModelPipeline(decoder="text_sonar_basic_decoder", tokenizer="text_sonar_basic_decoder", device=device)

    # Only the first 10 generated sequences are decoded to text.
    text_autoregr = [
        vec2text_model.predict(data, target_lang="eng_Latn", max_seq_len=64)
        for data in output_autoregr[:10]
    ]

    del vec2text_model, sonarprompt, sonaroutput, output_autoregr
    gc.collect()
    torch.cuda.empty_cache()

    # NOTE: trust_remote_code=True executes code shipped with the model repo.
    jasper = SentenceTransformer(
        "infgrad/jasper_en_vision_language_v1",
        trust_remote_code=True,
        device=device,
        # Compare the device *type* so an indexed device such as 'cuda:0'
        # also selects bfloat16.
        model_kwargs={"torch_dtype": torch.bfloat16 if device.type == 'cuda' else torch.float32},
    )
    jasper.max_seq_length = 1024

    jasper_emb = [
        torch.from_numpy(jasper.encode(data)).to(device)
        for data in text_autoregr
    ]

    del jasper
    gc.collect()
    torch.cuda.empty_cache()

    final_jasper, paragraphed_jasper = calculate_score(jasper_emb, jasperoutput)

    del jasper_emb, jasperoutput
    gc.collect()
    torch.cuda.empty_cache()

    # Keep the weights of runs that reach a minimum quality on either metric.
    if final_sonar > 0.4 or final_jasper > 0.4:
        torch.save([model.state_dict(), config], '/LCM/Base_LCM_' + str(round(final_sonar, 2)) + '_' + str(round(final_jasper, 2)) + '.pth')

    print('Jasper score:', final_jasper, paragraphed_jasper)
    wandb.log({'Final jasper': final_jasper, 'Paragraphed jasper': paragraphed_jasper})

    final_score = (final_sonar + final_jasper) / 2
    print('Final score:', final_score)
    wandb.log({'Final score': final_score})

In [None]:
from lcm.models.two_tower_diffusion_lcm.archs import base_lcm_max
from lcm.models.two_tower_diffusion_lcm.builder import BaseLCModelBuilder

def _snapshot_state(model):
    """Return a detached CPU copy of `model.state_dict()`."""
    return {
        key: value.detach().to('cpu', copy=True)
        for key, value in model.state_dict().items()
    }


def _test_best(model, best_state, test_data, config):
    """Load `best_state` into `model` (when one was recorded) and run `test`."""
    if best_state is not None:
        model.load_state_dict(best_state)
    test(model, test_data, config)


def objective(config, data):
    """Train one model for a sweep trial and report its final score.

    Training loops over the data shards indefinitely and validates every
    6400 training samples. It stops when the validation loss becomes NaN,
    when a loss threshold marks the run as unpromising, or when the loss has
    not improved for `patience` validations. On stop, the weights with the
    best validation loss seen so far are restored and passed to `test`.

    The previous version kept the "best" model as `model_save = model`, which
    is only a second name for the live model, so the final evaluation always
    used the latest weights. A CPU copy of the state dict is now taken
    whenever the validation loss improves.

    Args:
        config: run configuration; reads `lr`, `weight_decay`, `scheduler`,
            `T_0`, `batch_size`, `clip_grad` and `sonar_normalizer_name`.
        data: `(traindata, valloader, test_data)`.
    """
    traindata, valloader, test_data = data

    # NOTE(review): `sonar_normalizer_name` is commented out of the sweep
    # parameters in this notebook - confirm the config still provides it.
    #model = BaseLCModelBuilder(base_lcm_tuner(config), device=device).build_model(config.sonar_normalizer_name)
    model = BaseLCModelBuilder(base_lcm_max(config), device=device).build_model(config.sonar_normalizer_name)
    #model = BaseLCModelBuilder(base_lcm_max(config), device=device).build_model()

    optimizer = AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=config.T_0) if config.scheduler else None

    val_list = []
    cpt = 1
    loss_type = 1  # COSSIM:0   MSE:1
    patience = 3

    # Validate every 6400 samples. Integer division keeps the modulo test
    # exact for batch sizes that do not divide 6400 (a float interval would
    # then never match and validation would never run).
    checker = max(1, 6400 // config.batch_size)

    best_loss = math.inf
    best_state = None  # CPU snapshot of the best weights seen so far

    print('')

    # Shard 0 is the in-memory `traindata`; shards 1..9 are loaded from disk.
    for epoch in itertools.cycle(range(10)):
        if epoch != 0:
            # Drop the previous shard before loading the next one.
            del epoch_traindata, trainloader
            gc.collect()
            torch.cuda.empty_cache()
        if epoch % 10 == 0:
            epoch_traindata = traindata
        else:
            epoch_traindata = torch.from_numpy(np.load('/LCM/data/'+str(epoch)+'00k.npy'))

        trainloader = DataLoader(epoch_traindata, config.batch_size)

        model.train()

        for i, src in enumerate(trainloader):
            optimizer.zero_grad()
            src = src.to(device)
            outputs = model.predict_next_sentence(src[:, :-1])
            loss = criterion(outputs, src[:, 1:])[loss_type]
            loss.backward()

            if config.clip_grad:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1)

            optimizer.step()

            if scheduler is not None:
                scheduler.step()

            if i == 0 or i % checker != 0:
                continue

            val_list, cpt = validation(epoch, cpt, model, valloader, val_list)
            last_cossim, last_mse = val_list[-1]

            stop_and_test = False
            if math.isnan(last_mse):
                if len(val_list) == 1:
                    # Diverged before any usable checkpoint: report failure.
                    wandb.log({'Final score': -1})
                    return
                stop_and_test = True
            else:
                current = val_list[-1][loss_type]
                improved = current < best_loss
                if improved:
                    best_loss = current
                    best_state = _snapshot_state(model)

                # Runs whose MSE is still too high at fixed checkpoints are
                # cut short (first validation, cpt == 5, cpt == 30).
                unpromising = (
                    val_list[0][1] > 0.8
                    or (cpt == 5 and last_mse > 0.01)
                    or (cpt == 30 and last_mse > 5e-5)
                )
                if unpromising:
                    wandb.log({'Cossim': last_cossim, 'MSE': last_mse})
                    stop_and_test = True
                elif not improved:
                    # Early stopping: no new best within the last `patience` validations.
                    loss_list = [x[loss_type] for x in val_list]
                    if (len(loss_list) - loss_list.index(min(loss_list))) > patience:
                        print('Best: Part', cpt-1-patience, 'Cossim', val_list[-1-patience][0], 'MSE', val_list[-1-patience][1])
                        wandb.log({'Cossim': val_list[-1-patience][0], 'MSE': val_list[-1-patience][1]})
                        stop_and_test = True

            if stop_and_test:
                # Free training-only state (gradients, optimizer moments,
                # last batch) before test() loads its helper models.
                optimizer.zero_grad(set_to_none=True)
                del src, outputs, loss, optimizer, scheduler, trainloader, epoch_traindata
                gc.collect()
                torch.cuda.empty_cache()

                # compute Final score on the best checkpoint
                _test_best(model, best_state, test_data, config)
                return

In [None]:
# wandb sweep definition: Bayesian search maximising the 'Final score' metric.
# Commented-out entries are alternative search dimensions kept for reference.
sweep_conf = dict(
    method="bayes",
    metric=dict(goal="maximize", name="Final score"),
    parameters={
        # --- optimisation ---
        "batch_size": dict(values=[8, 16, 32, 64]),
        "scheduler": dict(values=[True, False]),
        "T_0": dict(min=10, max=20000),
        "lr": dict(distribution="log_uniform_values", min=1e-5, max=1e-1),
        "weight_decay": dict(distribution="log_uniform_values", min=1e-5, max=1e-1),
        "clip_grad": dict(values=[True, False]),

        # --- model size ---
        #"model_dim": {"values": [512, 1024, 2048]},
        "model_dim": dict(values=[2048]),
        #"num_attn_heads": {"values": [16]},

        #"sonar_normalizer_name": {"values": ["dummy_sonar_normalizer", "layernorm", None]},

        # --- frontend ---
        #"frontend_dropout_p": {"min": 0.0, "max": 0.5},
        #"frontend_pre_linear_init_fn": {"values": ['xavier', 'sonar', 'zero', 'trunc_normal', 'kaiming_uniform', 'none']},
        #"frontend_scale_embeddings": {"values": [True, False]},
        #"frontend_weight_normalization": {"values": [True, False]},

        # --- lcm ---
        #"lcm_final_dropout_p": {"min": 0.0, "max": 0.5},
        #"lcm_attention_dropout_p": {"min": 0.0, "max": 0.5},
        #"lcm_dropout_p": {"min": 0.0, "max": 0.5},
        #"lcm_ffn_inner_dim": {"values": [1, 2, 4]},
        "lcm_ffn_inner_dim": dict(values=[2]),
        #"lcm_num_layers": {"values": [2, 8, 16, 24, 32]},
        #"lcm_num_layers": {"values": [2, 8, 14, 24, 32]},
        "lcm_num_layers": dict(values=[18]),
        #"lcm_pos_embedding_style": {"values": ["rope", "sine", "learned", "none"]},
        #"lcm_use_swiglu": {"values": [True, False]},
        #"lcm_ffn_inner_activation_name": {"values": ["relu", "tanh", "elu", "leaky_relu", "prelu", "selu", "gelu", "silu", "softsign", "sigmoid", "hardsigmoid", None]},
        #"lcm_ffn_inner_activation_name": {"values": ["relu", "tanh", "elu", "leaky_relu", "selu", "gelu", "silu", "softsign", "sigmoid", "hardsigmoid", None]},
        #"lcm_layer_normalization_style": {"values": ["standard", "fp32", "rms", "unit"]},
        #"lcm_norm_order_style": {"values": ['pre', 'post', 'normformer']},
        #"lcm_final_norm_order_style": {"values": ['pre', 'post', 'normformer']},
        #"lcm_enable_qk_layernorm": {"values": [True, False]},
        #"lcm_mha_qkv_weight_normalization": {"values": [True, False]},
        #"lcm_mha_output_weight_normalization": {"values": [True, False]},
        #"lcm_mha_output_proj_bias": {"values": [True, False]},
        #"lcm_attention_output_init_fn": {"values": ['xavier', 'sonar', 'zero', 'trunc_normal', 'kaiming_uniform', 'none']},

        # --- postnet ---
        #"postnet_dropout_p": {"min": 0.0, "max": 0.5},
        #"postnet_linear_init_fn": {"values": ['xavier', 'sonar', 'zero', 'trunc_normal', 'kaiming_uniform', 'none']},
        #"postnet_weight_normalization": {"values": [True, False]},
        #"postnet_layer_normalization_style": {"values": ["standard", "fp32", "rms", "unit"]},
        #"postnet_activation_name": {"values": ["relu", "tanh", "elu", "leaky_relu", "prelu", "selu", "gelu", "silu", "softsign", "sigmoid", "hardsigmoid", None]},
    },
)

# Load the first training shard and the validation split from one archive;
# the archive is closed as soon as both arrays have been read.
with np.load('/LCM/data/100k_1k_0.npz') as np_data:
    traindata = torch.from_numpy(np_data['train'])
    valloader = DataLoader(torch.from_numpy(np_data['val']), batch_size=64)

# Evaluation bundle, unpacked by `test` as (sonarprompt, sonaroutput, jasperoutput).
test_data = torch.load('/LCM/data/test_sonarprompt_sonaroutput_jasperoutput.pth')

torch.cuda.empty_cache()
del np_data
gc.collect()


#os.environ["WANDB_SILENT"] = "true"

def main():
    """Run a single sweep trial.

    Starts a wandb run and trains/evaluates with the run's config. Any
    exception is printed with its traceback and reported to the sweep as a
    'Final score' of -1, so a failed trial does not stop the agent.
    """
    wandb.init()
    try:
        trial_inputs = [traindata, valloader, test_data]
        objective(wandb.config, trial_inputs)
    except Exception as failure:
        print(failure)
        traceback.print_exc()
        wandb.log({'Final score': -1})
    
# Register the sweep defined by `sweep_conf` under the "LCM" project.
sweep_id = wandb.sweep(sweep_conf, project="LCM")

# Start an agent for that sweep; it calls `main` for each trial.
wandb.agent(sweep_id, function=main)