In [None]:
# jupyter nbconvert --to script fpt_training.ipynb
%load_ext autoreload
%autoreload 2

import sys
import time

import torch
from torch.utils.data import DataLoader
import wandb
import numpy as np
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorWithPadding

from moment.utils.config import Config
from moment.utils.utils import parse_config, control_randomness
from moment.data.image_datasets import CIFAR10GrayDataset, CIFAR10Dataset, MNISTDataset
from moment.data.bit_datasets import BitMemoryDataset, BitXORDataset
from moment.data.nlp_datasets import NLPDataset
from moment.models.fpt import FrozenPretrainedTransformer
from moment.scripts.development.fpt_trainer import Trainer

## TODO: Verify that every source of randomness is controlled (control_randomness is called with args.random_seed once the config is parsed)

In [None]:
# Index of the CUDA device to use whenever a GPU is available.
GPU_ID = 2
# Passed to Config as `default_config_file_path`.
DEFAULT_CONFIG_PATH = "../../configs/default.yaml"

# FPT model config: keep exactly one `config_file_path` assignment active.
config_file_path = "../../configs/frozen_pretrained_transformer/moment_large_fpt.yaml"
# config_file_path = "../../configs/frozen_pretrained_transformer/flant5_large_fpt.yaml"
# config_file_path = "../../configs/frozen_pretrained_transformer/gpt2_med_fpt.yaml"

config = Config(
    config_file_path=config_file_path,
    default_config_file_path=DEFAULT_CONFIG_PATH,
).parse()

# Run on the selected GPU when CUDA is present, otherwise on the CPU.
if torch.cuda.is_available():
    config['device'] = GPU_ID
else:
    config['device'] = 'cpu'

args = parse_config(config)
# Apply the configured random seed before any model or data is created.
control_randomness(args.random_seed)

In [None]:
# Build the frozen pretrained transformer from the parsed config and move it
# to the device chosen in the configuration cell.
FPT = FrozenPretrainedTransformer(configs=args)
FPT = FPT.to(args.device)

# Compile only on Python 3.10 and older.
# `sys.version_info` is a 5-field tuple such as (3, 10, 12, 'final', 0), so
# comparing it with `<= (3, 10)` is False for every 3.10.x interpreter; the
# strict upper bound `< (3, 11)` includes all 3.10 releases as intended.
# NOTE(review): presumably the guard exists because torch.compile lacked
# Python 3.11 support in early PyTorch 2.x releases -- confirm against the
# installed PyTorch version.
if sys.version_info < (3, 11):
    print('Compiling FPT model...')
    FPT = torch.compile(FPT)
FPT

In [None]:
# Device on which the dataset places its batches. Uses the same CUDA
# availability check as `config['device']` so that the notebook does not
# request a CUDA device on a machine without one.
# NOTE(review): assumes the dataset classes accept 'cpu' as `device` -- confirm.
data_device = f'cuda:{GPU_ID}' if torch.cuda.is_available() else 'cpu'

# Keep exactly one dataset active; the others are alternatives.
# dataset = MNISTDataset(batch_size=args.batch_size, patch_size=4, device=data_device)
dataset = CIFAR10Dataset(batch_size=args.batch_size, patch_size=4, device=data_device)
# dataset = BitMemoryDataset(n=1000, num_patterns=5, device=data_device)
# dataset = NLPDataset(
#     dataset_name='imdb',
#     model_name=args.model_name,
#     batch_size=args.batch_size,
#     device=data_device
# )

# Draw one batch as a sanity check and display the input/target shapes.
x, y = dataset.get_batch(batch_size=args.batch_size)
x.shape, y.shape

In [None]:
ce_loss = torch.nn.CrossEntropyLoss()

# original - CIFAR10, MNIST
def loss_fn(out, y, x=None):
    """Cross-entropy between the logits at index 0 of dim 1 and the targets.

    Args:
        out: Model output; only the slice ``out[:, 0]`` is scored.
        y: Targets passed to ``torch.nn.CrossEntropyLoss``.
        x: Unused; kept so every loss function shares one signature.

    Returns:
        The scalar loss produced by ``ce_loss``.
    """
    first_position_logits = out[:, 0]
    loss = ce_loss(first_position_logits, y)
    return loss

def accuracy_fn(preds, true, x=None):
    """Fraction of samples whose arg-max class at index 0 of dim 1 equals the target.

    Works for both NumPy arrays and torch tensors.

    Args:
        preds: Scores indexed as ``preds[:, 0]`` with the class axis last.
        true: Target class indices, one per sample.
        x: Unused; kept so every accuracy function shares one signature.

    Returns:
        The mean of the element-wise match mask (a NumPy scalar for array
        inputs, a 0-dim float tensor for tensor inputs).
    """
    matches = preds[:, 0].argmax(-1) == true
    # torch cannot take the mean of a bool tensor (it raises a RuntimeError),
    # so cast the mask to float first; NumPy bool arrays average directly.
    if isinstance(matches, torch.Tensor):
        matches = matches.float()
    return matches.mean()

# # IMDB
# def loss_fn(out, y, x=None):
#     return ce_loss(out, y)

# def accuracy_fn(preds, true, x=None):
#     pred_labels = np.argmax(preds, axis=1)
#     correct = np.sum(pred_labels == true)
#     accuracy = correct / true.shape[0]
#     return accuracy

# # Bit-Memory
# def loss_fn(out, y, x=None):
#     out = torch.reshape(out, (-1, 1000, 2))
#     ids = torch.zeros(y.shape).to(device=y.device).long()
#     ids[y < 0], ids[y > 0] = 0, 1
#     out, ids = torch.reshape(out, (-1, 2)), torch.reshape(ids, (-1,))
#     return ce_loss(out, ids)

# def accuracy_fn(preds, true, x=None):
#     preds = preds.reshape(-1, 1000, 2).argmax(-1) * 2 - 1
#     return (np.sign(preds) == np.sign(true)).mean()

In [None]:
# Wire the model, dataset, and metric functions into the trainer.
trainer = Trainer(
    FPT,
    dataset,
    loss_fn,
    accuracy_fn=accuracy_fn,
    steps_per_epoch=args.steps_per_epoch,
    # Evaluate for 20% as many steps as training. Clamp to at least one step
    # so that steps_per_epoch values below 5 do not truncate to zero.
    test_steps_per_epoch=max(1, int(args.steps_per_epoch * 0.2)),
    learning_rate=args.learning_rate,
    batch_size=args.batch_size,
    eval_batch_size=1,
    grad_accumulate=1,
)

# Logging is disabled entirely when the config sets `debug`.
wandb.init(
    project='Time-series Foundation Model',
    name='FPT - CIFAR10 - MOMENT',
    config=args,
    mode='disabled' if args.debug else 'run',
)

total_steps = 0
try:
    for i in range(args.num_epochs):
        trainer.train_epoch()
        total_steps += args.steps_per_epoch

        # The metrics below are read from the trainer after each epoch.
        # NOTE(review): assumes train_epoch() refreshes these four
        # diagnostics keys -- confirm against the Trainer implementation.
        diagnostics = trainer.diagnostics
        wandb.log({
            'Train Loss': diagnostics['Average Train Loss'],
            'Test Loss': diagnostics['Test Loss'],
            'Train Accuracy': diagnostics['Train Accuracy'],
            'Test Accuracy': diagnostics['Test Accuracy'],
            'Epoch': i,  # zero-based, unlike the one-based printout below
            'Steps': total_steps,
        })
        print(
            f'Epoch {i+1}/{args.num_epochs} '
            f'Train Loss: {diagnostics["Average Train Loss"]:.3f} '
            f'Test Loss: {diagnostics["Test Loss"]:.3f} '
            f'Train Accuracy: {diagnostics["Train Accuracy"]:.3f} '
            f'Test Accuracy: {diagnostics["Test Accuracy"]:.3f}'
        )
finally:
    # Close the wandb run even if training raises or is interrupted, so the
    # run is not left open in the notebook kernel.
    wandb.finish()