# Debug Test Script EnergyDiff

In [None]:
%matplotlib inline
from pathlib import Path

import numpy as np
import torch
import pytorch_lightning as pl
from matplotlib import pyplot as plt
import pandas as pd

from src.opensynth.data_modules.lcl_data_module import LCLDataModule, LCLData
from src.opensynth.models.energydiff import diffusion, model

In [None]:
class LCLDataModuleWithValidation(LCLDataModule):
    """LCLDataModule that additionally builds a validation dataset.

    The validation set is a separate ``LCLData`` instance of ``n_val_samples``
    samples, created from the same ``data_path`` / ``stats_path`` /
    ``outlier_path`` attributes as the training data.

    NOTE(review): the validation set is read from the same file as the
    training set. Unless ``LCLData`` draws its samples at random (or the
    parent module excludes them), the two may overlap — confirm before
    treating the validation loss as held-out.
    """

    def __init__(self, data_path, stats_path, batch_size=32, n_samples=1000, outlier_path=None, n_val_samples=200):
        # assumes LCLDataModule's positional order is
        # (data_path, stats_path, batch_size, n_samples) and that it stores
        # data_path / stats_path / outlier_path / batch_size as attributes,
        # which setup() and val_dataloader() read — TODO confirm
        super().__init__(data_path, stats_path, batch_size, n_samples, outlier_path=outlier_path)
        self.n_val_samples = n_val_samples

    def setup(self, stage=None):
        """Set up the parent's datasets, then build the validation dataset."""
        super().setup(stage)
        self.val_dataset = LCLData(
            data_path=self.data_path,
            stats_path=self.stats_path,
            n_samples=self.n_val_samples,
            outlier_path=self.outlier_path,
        )

    def val_dataloader(self):
        """Return an unshuffled loader over every validation sample."""
        # drop_last=False so every validation sample is evaluated. With
        # drop_last=True, up to batch_size - 1 samples are skipped, and when
        # n_val_samples < batch_size the loader yields no batches at all, so
        # no validation loss would ever be logged.
        return torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            drop_last=False,
            shuffle=False,
        )
        

In [None]:
# prep data
data_path = Path("data/processed/historical/train/lcl_data.csv")
stats_path = Path("data/processed/historical/train/mean_std.csv")
outlier_path = Path("data/processed/historical/train/outliers.csv")

# Pass outlier_path through to the data module (it forwards it to both the
# training and the validation LCLData). Without it, the module is built with
# outlier_path=None and the path above goes unused. n_val_samples is spelled
# out (it equals the class default) so both split sizes are visible here.
dm = LCLDataModuleWithValidation(
    data_path=data_path,
    stats_path=stats_path,
    batch_size=200,
    n_samples=2000,
    outlier_path=outlier_path,
    n_val_samples=200,
)
dm.setup()

In [None]:
# fake data
def get_noisy_sin(n_samples, n_points=48, noise_std=0.01, rng=None):
    """Generate phase-shifted sine curves with additive Gaussian noise.

    Each curve is ``sin(x + phase)`` evaluated on ``n_points`` evenly spaced
    values of ``x`` in ``[0, 3/2 * pi]``. ``phase`` is drawn uniformly from
    ``[0, 2 * pi)`` once per curve, and i.i.d. noise with standard deviation
    ``noise_std`` is added to every point.

    Args:
        n_samples: number of curves to generate (0 is allowed).
        n_points: number of points per curve.
        noise_std: standard deviation of the additive Gaussian noise.
        rng: optional ``numpy.random.Generator`` for reproducible draws;
            defaults to NumPy's global random state.

    Returns:
        dict with ``'kwh'``: float32 tensor of shape ``(n_samples, n_points)``,
        and ``'features'``: float32 tensor of shape ``(n_points,)`` holding x.
    """
    rng = np.random if rng is None else rng
    x = np.linspace(0, 3 / 2 * np.pi, n_points)
    # One phase per curve (column vector) broadcast across that curve's
    # points. This also keeps the result 2-D when n_samples == 0.
    phases = rng.uniform(0, 2 * np.pi, size=(n_samples, 1))
    noise = rng.normal(0, noise_std, size=(n_samples, n_points))
    curves = np.sin(x + phases) + noise
    return {
        'kwh': torch.from_numpy(curves).float(),
        'features': torch.from_numpy(x).float(),
    }
    
class SinDataset(torch.utils.data.Dataset):
    """Map-style dataset of noisy sine curves produced by ``get_noisy_sin``.

    Every item pairs one curve (``'kwh'``) with the x-grid shared by all
    curves (``'features'``).
    """

    def __init__(self, n_samples=1000):
        self.data = get_noisy_sin(n_samples)

    def __len__(self):
        curves = self.data['kwh']
        return len(curves)

    def __getitem__(self, idx):
        store = self.data
        curve = store['kwh'][idx]
        return dict(kwh=curve, features=store['features'])

class SinDataModule(pl.LightningDataModule):
    """Lightning data module serving synthetic ``SinDataset`` splits.

    The training loader shuffles; the validation loader keeps order. Both use
    ``batch_size``. Fresh random curves are generated on every ``setup`` call.
    """

    def __init__(self, n_samples=1000, batch_size=32, n_val_samples=0):
        super().__init__()
        self.n_samples, self.n_val_samples, self.batch_size = n_samples, n_val_samples, batch_size

    def setup(self, stage=None):
        # Train split first, then validation (both draw from the global RNG).
        self.train_dataset = SinDataset(self.n_samples)
        self.val_dataset = SinDataset(self.n_val_samples)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
        )


In [None]:
# dm = SinDataModule(n_samples=1000, batch_size=200, n_val_samples=600)
# dm.setup()

In [None]:
# prep model
from pytorch_lightning.loggers import CSVLogger

df_model = diffusion.PLDiffusion1D(
    dim_base=128,
    dim_in=1,
    num_attn_head=4,
    num_decoder_layer=12,
    dim_feedforward=512,
    dropout=0.1,
    learn_variance=False,
    num_timestep=500,
    model_mean_type=diffusion.ModelMeanType.V,
    model_variance_type=diffusion.ModelVarianceType.FIXED_SMALL,
    loss_type=diffusion.LossType.MSE,
    beta_schedule_type=diffusion.BetaScheduleType.COSINE,
    lr=1e-3,
    ema_update_every=1,
    ema_decay=0.99,
)
# Use a CSV logger explicitly: the loss-curve cell reads
# <log_dir>/metrics.csv, which CSVLogger writes. When tensorboard is
# installed, Lightning's default logger is TensorBoard-based and produces no
# metrics.csv. save_dir="." keeps the usual ./lightning_logs/version_N
# layout, with checkpoints underneath.
trainer = pl.Trainer(
    gradient_clip_val=1.0,
    gradient_clip_algorithm="norm",
    max_epochs=500,
    logger=CSVLogger(save_dir="."),
)

In [None]:
# df_model = diffusion.PLDiffusion1D.load_from_checkpoint('lightning_logs/version_38/checkpoints/epoch=1-step=20.ckpt',)

In [None]:
# training
# Optional checkpoint to resume from. Passing it as ckpt_path resumes the
# model ALSO along with the training state. Set resume_ckpt = None to train
# from scratch. A missing file falls back to a fresh run instead of making
# fit() fail.
resume_ckpt = Path("lightning_logs/version_38/checkpoints/epoch=1-step=20.ckpt")
if resume_ckpt is not None and not resume_ckpt.exists():
    print(f"Checkpoint {resume_ckpt} not found; training from scratch.")
    resume_ckpt = None

trainer.fit(
    df_model,
    dm,
    ckpt_path=None if resume_ckpt is None else str(resume_ckpt),
)

In [None]:
# Load the logged losses of the run just trained. This assumes the trainer's
# logger wrote <log_dir>/metrics.csv. Train and validation values sit on
# different rows, so NaNs are dropped per column.
log_dir = trainer.logger.log_dir
metrics = pd.read_csv(Path(log_dir) / "metrics.csv")
epoch_train_loss, epoch_val_loss = (
    metrics[column].dropna().values for column in ("train_loss_epoch", "val_loss")
)

In [None]:
# Training vs. validation loss. The x axis is the index of each logged value.
loss_curves = (
    (epoch_train_loss, 'Training Loss'),
    (epoch_val_loss, 'Validation Loss'),
)
for values, curve_label in loss_curves:
    plt.plot(values, label=curve_label)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.show()

In [None]:
# sample
# Draw samples from the EMA copy of the denoiser.
# - eval(): the model was built with dropout=0.1, and dropout should be off
#   while sampling (assumes ema_model is an nn.Module — TODO confirm).
# - no_grad(): sampling needs no gradients. Without it, the returned tensors
#   may require grad, and a later .numpy() call on them would fail.
ema_df_model = df_model.ema.ema_model # GaussianDiffusion1D
ema_df_model.eval()
with torch.no_grad():
    # positional arguments as defined by GaussianDiffusion1D's sampling
    # methods — TODO confirm their meaning and pass them by keyword
    ans_samples = ema_df_model.ancestral_sample(50, 50, 48, 1)
    dpm_samples = ema_df_model.dpm_solver_sample(50, 50, 100, (48, 1))
# assumes the LCL data module exposes its training set as `dataset` and that
# it supports slice indexing (SinDataModule names it `train_dataset` instead)
true_samples = dm.dataset[0:50]['kwh']

In [None]:
# Mean curve of each sample set. detach() first (as in the per-sample plots)
# so .numpy() also works on tensors that still track gradients.
plt.plot(dpm_samples.detach().mean(dim=0).cpu().numpy(), label='DPM Solver')
plt.plot(ans_samples.detach().mean(dim=0).cpu().numpy(), label='Ancestral Sampling')
plt.plot(true_samples.detach().mean(dim=0).cpu().numpy(), label='True')
plt.legend()
plt.show()

In [None]:
# One panel per source, each showing up to 25 individual curves. Only the
# first curve of a panel carries a legend label.
fig, axs = plt.subplots(1, 3, figsize=(15, 5))

panel_specs = (
    (ans_samples, 'Ancestral Sampling', "ancestral", {'linestyle': '-.'}),
    (dpm_samples, 'DPM Solver', "dpm", {'linestyle': ':'}),
    (true_samples, 'True Samples', "true", {}),
)
for ax, (series, panel_title, first_label, line_kwargs) in zip(axs, panel_specs):
    n_shown = min(series.shape[0], 25)
    for row in range(n_shown):
        curve = series[row].detach().cpu().numpy()
        ax.plot(curve, label=first_label if row == 0 else None, **line_kwargs)
    ax.set_title(panel_title)
    ax.legend()

plt.tight_layout()
plt.show()