In [None]:
%load_ext autoreload
%autoreload 2

import torch
import numpy as np
import numpy.typing as npt
import matplotlib.pyplot as plt

from moment.utils.config import Config
from moment.utils.utils import parse_config
from moment.data.generate_synthetic_data import SyntheticDataset
from moment.models.base import BaseModel
from moment.models.moment import MOMENTNT

### Defaults

In [None]:
# Which pretraining run's checkpoint to load.
# Alternative run names noted by the author: "avid-moon-55", "proud-dust-41", "curious-blaze-53".
run_name = "fearless-planet-52"

# YAML config file (used as both the config and its defaults) and the CUDA device index.
DEFAULT_CONFIG_PATH = "../../configs/default.yaml"
GPU_ID = 0

In [None]:
# Load the pretrained weights for `run_name` (with `opt_steps=None`; see
# `BaseModel.load_pretrained_weights` for which optimisation step that selects).
checkpoint = BaseModel.load_pretrained_weights(run_name=run_name, opt_steps=None)

config = Config(config_file_path=DEFAULT_CONFIG_PATH, default_config_file_path=DEFAULT_CONFIG_PATH).parse()
# Fall back to the CPU when CUDA is unavailable so the notebook still runs end to end.
config['device'] = GPU_ID if torch.cuda.is_available() else 'cpu'

args = parse_config(config)

# The import cell at the top brings in `MOMENTNT`, not `MOMENT`, so `MOMENT` would be
# undefined here on a fresh kernel; import the class this cell actually instantiates.
from moment.models.moment import MOMENT

model = MOMENT(configs=args)
model.load_state_dict(checkpoint["model_state_dict"])

In [None]:
# Create the soft-prompt embeddings that are passed to `model.reconstruct` further down.
# NOTE(review): whether this returns a tensor or a module, and whether it also registers
# parameters on `model` (so that `model.to(device)` moves them), is not visible here —
# confirm in the MOMENT implementation before running on a GPU.
prompt_embeds = model.initialize_soft_prompt()

In [None]:
# Freeze the pretrained Transformer: any parameter whose name does not contain "soft"
# stops receiving gradients; parameters with "soft" in their name are left untouched.
for param_name, weight in model.named_parameters():
    if "soft" in param_name:
        continue
    weight.requires_grad = False

In [None]:
from moment.data.generate_synthetic_data import SyntheticDataset

In [None]:
# Seeded synthetic benchmark: 1024 noisy sinusoids whose frequency varies within
# freq_range, with Gaussian noise parameters noise_mean / noise_std.
synthetic_dataset = SyntheticDataset(
    n_samples=1024,
    freq=1,
    freq_range=(1, 32),
    noise_mean=0.0,
    noise_std=0.1,
    random_seed=13,
)

# y: generated series; c: the second value returned by the generator
# (presumably the per-series frequency — confirm in SyntheticDataset).
y, c = synthetic_dataset.gen_sinusoids_with_varying_freq()
n_samples, seq_len = synthetic_dataset.n_samples, synthetic_dataset.seq_len

In [None]:
# Put the data, the model and the input mask on the same device.
y = y.to(args.device)
model = model.to(args.device)
# Allocate the all-ones mask directly on the target device rather than building it
# on the CPU and copying it over (presumably "every time step observed" — confirm).
input_mask = torch.ones((n_samples, seq_len), device=args.device)
# NOTE(review): `prompt_embeds` was created before the model was moved; if it is not
# registered on `model`, it may still live on the CPU — confirm before running on GPU.

# Switch to evaluation mode; the trailing `;` suppresses the full module repr
# that would otherwise be printed as the cell's output.
model.eval();

In [None]:
# Reconstruct every series in one batched forward pass; no gradients are needed.
with torch.no_grad():
    outputs = model.reconstruct(
        x_enc=y,
        prompt_embeds=prompt_embeds,
        input_mask=input_mask,
    )
    # Bring the reconstructed series back to host memory as a NumPy array for plotting.
    reconstruction = outputs.reconstruction.detach().cpu().numpy()

In [None]:
# Compare one randomly chosen series with its reconstruction.
# A seeded generator keeps the chosen index (and therefore the figure) identical across
# Restart & Run All; change PLOT_SEED to inspect a different sample.
PLOT_SEED = 13
rng = np.random.default_rng(PLOT_SEED)
idx = int(rng.integers(0, n_samples))

fig, ax = plt.subplots()
ax.plot(y[idx, :].squeeze().detach().cpu().numpy(), label="True")
ax.plot(reconstruction[idx, :].squeeze(), label="Reconstruction")
ax.set(title=f"Sample {idx}: true vs. reconstructed", xlabel="Time step", ylabel="Value")
ax.legend();

In [None]:
# Taken from XXXX

import os
from pathlib import Path
import torch
import torch.nn as nn

class MOMENTPromptTuning:
    """Soft-prompt-tuning mixin.

    Prepends ``n_tokens`` learnable embedding vectors (the "soft prompt") to the
    input embeddings, and left-pads ``labels`` / ``attention_mask`` to match.

    NOTE(review): this class has no base class of its own; ``from_pretrained`` and
    ``forward`` delegate to ``super()``, so it only works when listed ahead of a
    model class that provides ``from_pretrained``, ``forward``, ``parameters``,
    ``config.n_embd``, ``transformer.wte`` and ``device``. Judging by those
    attribute names this targets a GPT-2-style model — confirm before mixing it
    into MOMENT.
    """

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        soft_prompt_path: str = None,
        n_tokens: int = None,
        initialize_from_vocab: bool = True,
        random_range: float = 0.5,
        **kwargs,
    ):
        """Load the wrapped pretrained model, freeze it, and attach a soft prompt.

        Args:
            pretrained_model_name_or_path: forwarded to the base class's
                ``from_pretrained``.
            soft_prompt_path: if given, load a saved soft prompt from this file
                (takes precedence over ``n_tokens``).
            n_tokens: if given (and ``soft_prompt_path`` is not), create a fresh
                soft prompt with this many vectors.
            initialize_from_vocab: initialize a fresh soft prompt from the first
                ``n_tokens`` rows of ``transformer.wte`` instead of uniform noise.
            random_range: half-width of the uniform init range used when not
                initializing from the vocabulary.
            **kwargs: forwarded to the base class's ``from_pretrained``.

        Returns:
            The loaded model with every pre-existing parameter frozen
            (``requires_grad=False``). The soft prompt is attached afterwards,
            so this freezing loop does not touch it.
        """
        model = super().from_pretrained(pretrained_model_name_or_path, **kwargs)

        # Freeze the pretrained Transformer before the soft prompt exists, so the
        # prompt is the only thing that can still be trained.
        for param in model.parameters():
            param.requires_grad = False

        if soft_prompt_path is not None:
            model.set_soft_prompt_embeds(soft_prompt_path)
        elif n_tokens is not None:
            print("Initializing soft prompt...")
            model.initialize_soft_prompt(
                n_tokens=n_tokens,
                initialize_from_vocab=initialize_from_vocab,
                random_range=random_range,
            )

        return model

    def set_soft_prompt_embeds(
        self,
        soft_prompt_path: str,
    ) -> None:
        """Load a soft-prompt module previously written by ``save_soft_prompt``.

        Args:
            soft_prompt_path: path of a ``torch.save`` file holding the
                ``nn.Embedding`` soft prompt. It is loaded onto the CPU.

        Security: ``torch.load`` unpickles the file, which can execute arbitrary
        code — only load soft prompts from trusted sources.
        NOTE(review): newer PyTorch releases default ``torch.load`` to
        ``weights_only=True``, which may refuse a pickled ``nn.Embedding``;
        confirm against the installed version.
        """
        self.soft_prompt = torch.load(
            soft_prompt_path, map_location=torch.device("cpu")
        )
        self.n_tokens = self.soft_prompt.num_embeddings
        print(f"Set soft prompt! (n_tokens: {self.n_tokens})")

    def initialize_soft_prompt(
        self,
        n_tokens: int = 20,
        random_range: float = 0.5,
        initialize_from_vocab: bool = False,
    ) -> None:
        """Create a fresh, trainable soft prompt of shape ``(n_tokens, config.n_embd)``.

        Args:
            n_tokens: number of soft-prompt vectors to prepend to every input.
            random_range: initial values are drawn from
                U(-random_range, random_range) when ``initialize_from_vocab`` is False.
            initialize_from_vocab: copy the first ``n_tokens`` rows of
                ``transformer.wte`` as the initial values instead.
        """
        self.n_tokens = n_tokens
        n_embd = self.config.n_embd

        if initialize_from_vocab:
            # Copy (not alias) the embedding rows so training the prompt can never
            # modify the frozen embedding table.
            init_prompt_value = self.transformer.wte.weight[:n_tokens].clone().detach()
        else:
            # Shape must be (n_tokens, n_embd) to match the nn.Embedding below.
            init_prompt_value = torch.empty(n_tokens, n_embd).uniform_(
                -random_range, random_range
            )

        self.soft_prompt = nn.Embedding(n_tokens, n_embd)
        self.soft_prompt.weight = nn.Parameter(init_prompt_value)

    def _prepend_soft_prompt(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        """Prepend the soft prompt to batched ``(batch, seq_len, n_embd)`` embeddings.

        Returns a tensor of shape ``(batch, n_tokens + seq_len, n_embd)``.
        """
        # [batch_size, n_tokens, n_embd]; `.to` keeps the graph back to the Parameter.
        learned_embeds = self.soft_prompt.weight.repeat(inputs_embeds.size(0), 1, 1)
        return torch.cat(
            [learned_embeds.to(inputs_embeds.device), inputs_embeds], dim=1
        )

    def _cat_learned_embedding_to_input(self, input_ids) -> torch.Tensor:
        """Embed ``input_ids`` with ``transformer.wte`` and prepend the soft prompt.

        A 1-D ``input_ids`` (one unbatched sequence) is treated as a batch of one.
        Assumes ``transformer.wte`` maps ids to ``(..., n_embd)`` embeddings — TODO
        confirm for non-GPT-2 backbones.
        """
        inputs_embeds = self.transformer.wte(input_ids)

        if inputs_embeds.dim() == 2:
            inputs_embeds = inputs_embeds.unsqueeze(0)

        return self._prepend_soft_prompt(inputs_embeds)

    def _extend_labels(self, labels, ignore_index=-100) -> torch.Tensor:
        """Left-pad ``labels`` with ``ignore_index`` for every soft-prompt position.

        1-D labels are treated as a batch of one. The padding uses the same dtype
        and device as ``labels`` so the concatenation never mixes them.
        (-100 is the default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``;
        presumably the wrapped model's loss uses it — confirm.)
        """
        if labels.dim() == 1:
            labels = labels.unsqueeze(0)

        prompt_labels = torch.full(
            (labels.shape[0], self.n_tokens),
            ignore_index,
            dtype=labels.dtype,
            device=labels.device,
        )
        return torch.cat([prompt_labels, labels], dim=1)

    def _extend_attention_mask(self, attention_mask):
        """Left-pad ``attention_mask`` with ones so the soft prompt is attended to.

        A 1-D mask is treated as a batch of one. The padding matches the mask's
        dtype and device.
        """
        if attention_mask.dim() == 1:
            attention_mask = attention_mask.unsqueeze(0)

        prompt_mask = torch.ones(
            (attention_mask.shape[0], self.n_tokens),
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )
        return torch.cat([prompt_mask, attention_mask], dim=1)

    def save_soft_prompt(self, path: str, filename: str = "soft_prompt.model"):
        """Pickle the soft-prompt module to ``path/filename`` with ``torch.save``.

        ``path`` (and any missing parents) is created if needed. The file can be
        restored with ``set_soft_prompt_embeds``.
        """
        Path(path).mkdir(parents=True, exist_ok=True)
        torch.save(self.soft_prompt, os.path.join(path, filename))

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        return_dict=None,
    ):
        """Run the wrapped model with the soft prompt prepended to the input.

        ``input_ids`` takes precedence over ``inputs_embeds`` when both are given.
        Whichever is used gets the soft prompt prepended, and ``labels`` /
        ``attention_mask`` (when given) are left-padded by ``n_tokens`` so that
        all sequence lengths agree. All of these are moved to ``self.device``.
        Other arguments of the base ``forward`` are not exposed here.
        """
        if input_ids is not None:
            inputs_embeds = self._cat_learned_embedding_to_input(input_ids).to(
                self.device
            )
        elif inputs_embeds is not None:
            # Pre-computed embeddings need the soft prompt too; otherwise they would
            # be n_tokens shorter than the extended mask/labels below.
            if inputs_embeds.dim() == 2:
                inputs_embeds = inputs_embeds.unsqueeze(0)
            inputs_embeds = self._prepend_soft_prompt(inputs_embeds).to(self.device)

        if labels is not None:
            labels = self._extend_labels(labels).to(self.device)

        if attention_mask is not None:
            attention_mask = self._extend_attention_mask(attention_mask).to(self.device)

        return super().forward(
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            return_dict=return_dict,
        )