In [1]:
import os

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from sklearn.cluster import KMeans

from tqdm.notebook import tqdm 

from torch.utils.tensorboard import SummaryWriter

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, TensorDataset
import torch.optim.lr_scheduler as lr_scheduler

from models.enhanced_baseline import EnhancedBaseLineModel
from ddpm.ddpm import GaussianDiffusion1D
from evaluation.evaluation import vizual_comparison, plot_jsd_per_customer, plot_kde_samples, make_gif_from_images, mmd_histogram_per_customer

# Params

In [2]:
## Data
seq_len = 48
batch_size = 128
k = 15

#NN
device = 'cuda' if torch.cuda.is_available() else 'cpu'
latent_dim = 1000
cond_model = "mlp"
num_layers = 6
n_heads = 8
lr = 5e-5
decay_rate = 0.9
epochs = 2000
save_rate = 100

## DDPM
timesteps = 1000
beta_schedule = "cosine"
objective = "pred_noise"

## Logging
experiment_name = "enhanced_dev"
logging_dir = f"./logging/{experiment_name}/"

In [3]:
if not os.path.isdir(logging_dir):
    os.makedirs(logging_dir)
    os.makedirs(os.path.join(logging_dir, "viz/"))
    os.makedirs(os.path.join(logging_dir, "jsd/"))
    os.makedirs(os.path.join(logging_dir, "kde/"))
    os.makedirs(os.path.join(logging_dir, "tensorboard/"))
    os.makedirs(os.path.join(logging_dir, "weights/"))
    os.makedirs(os.path.join(logging_dir, "ts_sample/"))
    os.makedirs(os.path.join(logging_dir, "mmd/"))

# Utils

In [4]:
PREPROCESSED_DIR = "./preprocessing/data/customer_led_network_revolution/preprocessed/"

In [5]:
class MakeDATA(Dataset):
    def __init__(self, data, seq_len):
        data = np.asarray(data, dtype=np.float32)
        seq_data = []
        for i in range(len(data) - seq_len + 1):
            x = data[i : i + seq_len]
            seq_data.append(x)
        self.samples = np.asarray(seq_data, dtype=np.float32) 

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

In [6]:
def cluster(data, k):
    kmeans = KMeans(n_clusters=k)
    cluster_labels = kmeans.fit_predict(data.T)
    
    clustered_data = []
    for cluster in range(kmeans.n_clusters):
        cluster_data = data.iloc[:, cluster_labels == cluster].mean(axis=1)
        clustered_data.append(cluster_data)
        
    return pd.DataFrame(clustered_data).T

# Load Data

In [7]:
train = pd.read_csv(os.path.join(PREPROCESSED_DIR, "train.csv"))
val = pd.read_csv(os.path.join(PREPROCESSED_DIR, "val.csv"))
test = pd.read_csv(os.path.join(PREPROCESSED_DIR, "test.csv"))
cond_train = pd.read_csv(os.path.join(PREPROCESSED_DIR, "cond_train.csv"))
cond_val = pd.read_csv(os.path.join(PREPROCESSED_DIR, "cond_val.csv"))
cond_test = pd.read_csv(os.path.join(PREPROCESSED_DIR, "cond_test.csv"))

In [8]:
train.drop("Date and Time of capture", axis=1, inplace=True)
val.drop("Date and Time of capture", axis=1, inplace=True)
test.drop("Date and Time of capture", axis=1, inplace=True)

In [9]:
train = cluster(train, k)
test = cluster(test, k)
val = cluster(val, k)

found 0 physical cores < 1
  File "c:\Users\Arne\anaconda3\envs\hf_diff\lib\site-packages\joblib\externals\loky\backend\context.py", line 282, in _count_physical_cores
    raise ValueError(f"found {cpu_count_physical} physical cores < 1")


In [10]:
train.shape, val.shape, test.shape, cond_train.shape, cond_val.shape, cond_test.shape

((33984, 15), (6796, 15), (1700, 15), (33984, 27), (6796, 27), (1700, 27))

In [11]:
train_seq = np.asarray(MakeDATA(train, seq_len)).transpose(0, 2, 1)
cond_train_seq = np.asarray(MakeDATA(cond_train, seq_len)).transpose(0, 2, 1)

val_seq = np.asarray(MakeDATA(val, seq_len)).transpose(0, 2, 1)
cond_val_seq = np.asarray(MakeDATA(cond_val, seq_len)).transpose(0, 2, 1)

test_seq = np.asarray(MakeDATA(test, seq_len)).transpose(0, 2, 1)
cond_test_seq = np.asarray(MakeDATA(cond_test, seq_len)).transpose(0, 2, 1)

train_seq.shape, cond_train_seq.shape, test_seq.shape, cond_test_seq.shape

((33937, 15, 48), (33937, 27, 48), (1653, 15, 48), (1653, 27, 48))

In [12]:
train_dataset = TensorDataset(torch.from_numpy(train_seq), torch.from_numpy(cond_train_seq))
train_loader = DataLoader(train_dataset, batch_size, shuffle=False)

val_dataset = TensorDataset(torch.from_numpy(val_seq), torch.from_numpy(cond_val_seq))
val_loader = DataLoader(val_dataset, batch_size, shuffle=False)

test_dataset = TensorDataset(torch.from_numpy(test_seq), torch.from_numpy(cond_test_seq))
test_loader = DataLoader(test_dataset, batch_size, shuffle=False)

In [13]:
real_data_val, real_cond_data_val = next(iter(val_loader))

# Load Model

In [14]:
# def init_weights(m):
#     if isinstance(m, nn.Linear):
#         torch.nn.init.xavier_uniform_(m.weight)
#         if m.bias is not None:
#             torch.nn.init.zeros_(m.bias)
#     elif isinstance(m, nn.LSTM):
#         for name, param in m.named_parameters():
#             if 'weight' in name:
#                 torch.nn.init.xavier_uniform_(param.data)
#             elif 'bias' in name:
#                 torch.nn.init.zeros_(param.data)

In [15]:
model = EnhancedBaseLineModel(features=train_seq.shape[1], hidden_dim=latent_dim, cond_dim=cond_train_seq.shape[1], cond_model=cond_model, device=device, channels=train_seq.shape[1])

# model.apply(init_weights)

ddpm = GaussianDiffusion1D(model, seq_length = seq_len, timesteps = timesteps, objective = objective, loss_type = 'l2', beta_schedule = beta_schedule)
ddpm = ddpm.to(device)

optim = torch.optim.Adam(ddpm.parameters(), lr = lr)
scheduler = lr_scheduler.StepLR(optim, step_size=1000, gamma=0.9)

writer = SummaryWriter(os.path.join(logging_dir, "tensorboard/"))

# Train

In [16]:
#taken from https://medium.com/@heyamit10/exponential-moving-average-ema-in-pytorch-eb8b6f1718eb
class EMA:
    def __init__(self, model, decay):
        """
        Initialize EMA class to manage exponential moving average of model parameters.
        
        Args:
            model (torch.nn.Module): The model for which EMA will track parameters.
            decay (float): Decay rate, typically a value close to 1, e.g., 0.999.
        """
        self.model = model
        self.decay = decay
        self.shadow = {}
        self.backup = {}

        # Store initial parameters
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self):
        """
        Update shadow parameters with exponential decay.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
                self.shadow[name] = new_average.clone()

    def apply_shadow(self):
        """
        Apply shadow (EMA) parameters to model.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.backup[name] = param.data.clone()
                param.data = self.shadow[name]

    def restore(self):
        """
        Restore original model parameters from backup.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                param.data = self.backup[name]

In [17]:
def save_model(model_dict, ema_shadow, optimizer_dict, filepath):
    checkpoint = {
        'model_state_dict': model_dict,
        'ema_state_dict': ema_shadow,
        'optimizer_state_dict': optimizer_dict,
    }
    torch.save(checkpoint, filepath)

In [18]:
def load_checkpoint(model, ema, optimizer, filepath):
    checkpoint = torch.load(filepath)
    model.load_state_dict(checkpoint['model_state_dict'])
    
    if 'ema_state_dict' in checkpoint:
        for name, param in model.named_parameters():
            if param.requires_grad and name in checkpoint['ema_state_dict']:
                ema.shadow[name] = checkpoint['ema_state_dict'][name].clone()
    
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [19]:
def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.
    
    for i, (data, cond_data) in enumerate(train_loader):
        data = data.to(device)
        
        cond_data = cond_data.float()
        cond_data = cond_data.to(device)
        
        optim.zero_grad()
        
        loss = ddpm(data, cond_data)
        with torch.autograd.set_detect_anomaly(True):
            loss.backward()
        
        #torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5)
        
        # ################# DEBUG ########################
        # """
        # Exploding gradients: Norms become very large (e.g., >100 or even inf).
        # Vanishing gradients: Norms shrink toward zero.
        # """
        # total_norm = 0.0
        # for p in model.parameters():
        #     if p.grad is not None:
        #         param_norm = p.grad.data.norm(2)
        #         total_norm += param_norm.item() ** 2
        # total_norm = total_norm ** 0.5
        # tb_writer.add_scalar('Gradient Norm', total_norm, epoch_index * len(train_loader) + i)
        
        # if total_norm > 1e3 or torch.isnan(torch.tensor(total_norm)):
        #     print(f"High or NaN gradient norm at step {i}: {total_norm:.4f}")
            
        # ################# DEBUG ########################
        
        optim.step()
        ema.update()
        
        running_loss += loss.item()
        
        if i % 100 == 0:
            last_loss = running_loss / 1000 # loss per batch
            tb_x = epoch_index * len(train_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.
    return last_loss

In [20]:
ema = EMA(model, decay=decay_rate)

In [21]:
epoch_number = 0
best_val_loss = 1_000_000.
save_dir =  os.path.join(logging_dir, "weights/")

ema = EMA(model, decay=decay_rate)

for epoch in tqdm(range(1, epochs+1)):
    print('epoch {}:'.format(epoch_number + 1))
    
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)
    
    running_val_loss = 0.0
    
    ema.apply_shadow()
    model.eval()
    with torch.no_grad():
        for i, (val_data, val_cond) in enumerate(val_loader):
            val_data = val_data.to(device)
        
            val_cond = val_cond.float()
            val_cond = val_cond.to(device)
            val_loss = ddpm(val_data, val_cond)
            
            running_val_loss += val_loss
    
    avg_val_loss = running_val_loss / (i + 1)
    print('Loss train {} val {}'.format(avg_loss, avg_val_loss))
    
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_val_loss },
                    epoch_number + 1)
    
    writer.flush()
    
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_model_dict = model.state_dict()
        best_model_opt = optim.state_dict()
        best_model_shadow = ema.shadow
        
    
    if epoch_number % save_rate == 0:
        generated_sample = ddpm.sample(batch_size, real_cond_data_val.to(device))
        generated_sample = generated_sample.cpu().numpy()
        
        plot_kde_samples(generated_sample, real_data_val,show=False, fpath=os.path.join(logging_dir, "kde/", f"kde_epoch_{epoch_number}.png"), epoch=epoch_number)
        
        model_name = f"model_epoch_{epoch_number}_val_{avg_val_loss:.4f}.pth"
        save_model(model.state_dict(), ema.shadow, optim.state_dict(), os.path.join(save_dir, model_name))
    
    ema.restore()
    epoch_number += 1

  0%|          | 0/2000 [00:00<?, ?it/s]

epoch 1:


  result = _VF.lstm(


Loss train 1.3463373770713807 val 1.2723420858383179


sampling loop time step: 100%|██████████| 1000/1000 [00:41<00:00, 23.95it/s]


epoch 2:
Loss train 0.11254887318611145 val 1.298607349395752
epoch 3:
Loss train 0.08954696607589721 val 1.286897897720337
epoch 4:
Loss train 0.07518663370609284 val 1.429446816444397
epoch 5:
Loss train 0.05480870220065117 val 1.1906548738479614
epoch 6:
Loss train 0.0468406862616539 val 1.316590428352356
epoch 7:
Loss train 0.036303610414266586 val 1.2452750205993652
epoch 8:
Loss train 0.03073615598678589 val 1.0947567224502563
epoch 9:
Loss train 0.028694360673427583 val 1.4266196489334106
epoch 10:
Loss train 0.02436915910243988 val 0.8712239861488342
epoch 11:
Loss train 0.019972516730427742 val 0.8117837309837341
epoch 12:
Loss train 0.02157553370296955 val 2.108886241912842
epoch 13:
Loss train 0.017647766202688216 val 0.8306003212928772
epoch 14:
Loss train 0.01647813509404659 val 0.6700195074081421
epoch 15:
Loss train 0.015940967231988908 val 1.0412626266479492
epoch 16:
Loss train 0.01641804249584675 val 1.2835571765899658
epoch 17:
Loss train 0.014072835259139539 val 1.0

sampling loop time step: 100%|██████████| 1000/1000 [00:43<00:00, 23.20it/s]


epoch 102:
Loss train 0.003325168455019593 val 0.13092674314975739
epoch 103:
Loss train 0.0033609915636479853 val 0.13483203947544098
epoch 104:
Loss train 0.0032419493962079285 val 0.12412938475608826
epoch 105:
Loss train 0.003339411111548543 val 0.12416158616542816
epoch 106:
Loss train 0.003174264112487435 val 0.13317029178142548
epoch 107:
Loss train 0.0033208289500325917 val 0.12906573712825775
epoch 108:
Loss train 0.0034286395329982044 val 0.12547515332698822
epoch 109:
Loss train 0.003240148330107331 val 0.13292475044727325
epoch 110:
Loss train 0.0032334995539858935 val 0.12781357765197754
epoch 111:
Loss train 0.003269212333485484 val 0.13310512900352478
epoch 112:
Loss train 0.003302158256992698 val 0.13048326969146729
epoch 113:
Loss train 0.003169633448123932 val 0.13099713623523712
epoch 114:
Loss train 0.003125088242813945 val 0.1254284828901291
epoch 115:
Loss train 0.0032639253493398426 val 0.13477428257465363
epoch 116:
Loss train 0.003196832099929452 val 0.13134926

sampling loop time step: 100%|██████████| 1000/1000 [00:44<00:00, 22.44it/s]


epoch 202:
Loss train 0.003012100614607334 val 0.13072076439857483
epoch 203:
Loss train 0.0029199567921459673 val 0.13456125557422638
epoch 204:
Loss train 0.0030235277619212867 val 0.13888585567474365
epoch 205:
Loss train 0.002998426027595997 val 0.13081631064414978
epoch 206:
Loss train 0.0028383654933422805 val 0.1327795535326004
epoch 207:
Loss train 0.00287128964997828 val 0.12915174663066864
epoch 208:
Loss train 0.00283648444339633 val 0.13492359220981598
epoch 209:
Loss train 0.002935904048383236 val 0.12963756918907166
epoch 210:
Loss train 0.002960858955979347 val 0.13095463812351227
epoch 211:
Loss train 0.0029254936017096044 val 0.12851200997829437
epoch 212:
Loss train 0.002930965529754758 val 0.13247685134410858
epoch 213:
Loss train 0.0029447086434811354 val 0.13306021690368652
epoch 214:
Loss train 0.00293585417047143 val 0.13165943324565887
epoch 215:
Loss train 0.0028458692599087954 val 0.1285165250301361
epoch 216:
Loss train 0.0029023040682077408 val 0.13146722316

sampling loop time step: 100%|██████████| 1000/1000 [00:42<00:00, 23.38it/s]


epoch 302:
Loss train 0.002836126321926713 val 0.13615280389785767
epoch 303:
Loss train 0.002726882940158248 val 0.133378803730011
epoch 304:
Loss train 0.002822103740647435 val 0.13866525888442993
epoch 305:
Loss train 0.0027229377180337908 val 0.1384442299604416
epoch 306:
Loss train 0.002718561850488186 val 0.1359541416168213
epoch 307:
Loss train 0.002691067675128579 val 0.14046527445316315
epoch 308:
Loss train 0.0027388985697180033 val 0.14290377497673035
epoch 309:
Loss train 1.6332069532480091 val 0.134016752243042
epoch 310:
Loss train 0.002672427276149392 val 0.13807179033756256
epoch 311:
Loss train 0.0028018085695803168 val 0.13587996363639832
epoch 312:
Loss train 0.002656214131042361 val 0.13116583228111267
epoch 313:
Loss train 0.002662895732559264 val 0.13699206709861755
epoch 314:
Loss train 0.0027674139756709336 val 0.14225322008132935
epoch 315:
Loss train 0.002672420619055629 val 0.14050118625164032
epoch 316:
Loss train 0.00270183583162725 val 0.1346438229084015
e

sampling loop time step: 100%|██████████| 1000/1000 [00:41<00:00, 24.34it/s]


epoch 402:
Loss train 0.002596244404092431 val 0.14195367693901062
epoch 403:
Loss train 0.0024650017712265254 val 0.1410205066204071
epoch 404:
Loss train 0.002473275063559413 val 0.14689762890338898
epoch 405:
Loss train 0.002541146144270897 val 0.14273948967456818
epoch 406:
Loss train 0.0025177181623876094 val 0.13828735053539276
epoch 407:
Loss train 0.0025561963301151992 val 0.14835837483406067
epoch 408:
Loss train 0.002453950550407171 val 0.1449754685163498
epoch 409:
Loss train 0.002487912547774613 val 0.1398288607597351
epoch 410:
Loss train 0.002473792776465416 val 0.14071331918239594
epoch 411:
Loss train 0.0024648985732346773 val 0.13863202929496765
epoch 412:
Loss train 0.0025293112201616166 val 0.1451311707496643
epoch 413:
Loss train 0.0024540280252695083 val 0.140900656580925
epoch 414:
Loss train 0.0023843123503029346 val 0.1347668319940567
epoch 415:
Loss train 0.0023760896315798163 val 0.13788148760795593
epoch 416:
Loss train 0.0025597570072859524 val 0.14393338561

sampling loop time step: 100%|██████████| 1000/1000 [00:40<00:00, 24.42it/s]


epoch 502:
Loss train 0.0022654493041336537 val 0.13944584131240845
epoch 503:
Loss train 0.0023193620555102827 val 0.14306333661079407
epoch 504:
Loss train 0.0022076824894174934 val 0.1415918916463852
epoch 505:
Loss train 0.002264402598142624 val 0.14486399292945862
epoch 506:
Loss train 0.002302089028060436 val 0.13737307488918304
epoch 507:
Loss train 0.002247958027757704 val 0.14782369136810303
epoch 508:
Loss train 0.0022890255833044647 val 0.13814523816108704
epoch 509:
Loss train 0.002184304120019078 val 0.14578919112682343
epoch 510:
Loss train 0.002251549968495965 val 0.14321103692054749
epoch 511:
Loss train 0.002242999425157905 val 0.14044177532196045
epoch 512:
Loss train 0.0021890879552811386 val 0.14809174835681915
epoch 513:
Loss train 0.002211575559340417 val 0.14332067966461182
epoch 514:
Loss train 0.0023007832700386644 val 0.1493394821882248
epoch 515:
Loss train 0.0022483365619555117 val 0.14246854186058044
epoch 516:
Loss train 0.0022245914097875357 val 0.1414531

sampling loop time step: 100%|██████████| 1000/1000 [00:41<00:00, 24.32it/s]


epoch 602:
Loss train 0.0019281487921252846 val 0.1486426144838333
epoch 603:
Loss train 0.002012133370153606 val 0.14816483855247498
epoch 604:
Loss train 0.0020105648450553415 val 0.1536964774131775
epoch 605:
Loss train 0.002013957227580249 val 0.14736483991146088
epoch 606:
Loss train 0.0020708226710557938 val 0.14823223650455475
epoch 607:
Loss train 0.0019476202744990586 val 0.15227270126342773
epoch 608:
Loss train 0.0019398686597123741 val 0.14290374517440796
epoch 609:
Loss train 0.0020496798902750017 val 0.14608699083328247
epoch 610:
Loss train 0.002083740313537419 val 0.15028737485408783
epoch 611:
Loss train 0.0019736134307459 val 0.1469837725162506
epoch 612:
Loss train 0.0019319475470110773 val 0.14882658421993256
epoch 613:
Loss train 0.0019712087251245974 val 0.15242347121238708
epoch 614:
Loss train 0.0019428011626005173 val 0.14364561438560486
epoch 615:
Loss train 0.0020233078468590974 val 0.13476337492465973
epoch 616:
Loss train 0.0020317101422697307 val 0.1447516

sampling loop time step: 100%|██████████| 1000/1000 [00:41<00:00, 24.34it/s]


epoch 702:
Loss train 0.0018780720625072719 val 0.14889372885227203
epoch 703:
Loss train 0.0018082437394186855 val 0.15052905678749084
epoch 704:
Loss train 0.20390807277988643 val 0.14602351188659668
epoch 705:
Loss train 0.0018700958341360092 val 0.14462022483348846
epoch 706:
Loss train 0.001897739932872355 val 0.15038061141967773
epoch 707:
Loss train 0.0018663711836561562 val 0.15377874672412872
epoch 708:
Loss train 0.001819675898179412 val 0.1479349434375763
epoch 709:
Loss train 0.001858652425929904 val 0.15154646337032318
epoch 710:
Loss train 0.0018654292086139322 val 0.15192177891731262
epoch 711:
Loss train 0.0018412195891141891 val 0.15523923933506012
epoch 712:
Loss train 0.0018521779524162411 val 0.14984703063964844
epoch 713:
Loss train 0.0017623327076435088 val 0.14952488243579865
epoch 714:
Loss train 0.0018070061523467302 val 0.1559743732213974
epoch 715:
Loss train 0.0018167939027771353 val 0.15186168253421783
epoch 716:
Loss train 0.0018246699776500463 val 0.15549

sampling loop time step: 100%|██████████| 1000/1000 [00:41<00:00, 24.34it/s]


epoch 802:
Loss train 0.0017339678881689907 val 0.15124759078025818
epoch 803:
Loss train 0.001695038340985775 val 0.1509517878293991
epoch 804:
Loss train 0.0016593516115099192 val 0.1624843180179596
epoch 805:
Loss train 0.0016832643570378423 val 0.15924693644046783
epoch 806:
Loss train 0.0017579307118430734 val 0.15793012082576752
epoch 807:
Loss train 0.001722444494254887 val 0.15176518261432648
epoch 808:
Loss train 0.0017574969111010432 val 0.15235337615013123
epoch 809:
Loss train 0.0017059246180579067 val 0.15495409071445465
epoch 810:
Loss train 0.0018163103740662337 val 0.15893042087554932
epoch 811:
Loss train 0.0017268818197771908 val 0.15285667777061462
epoch 812:
Loss train 0.0017896259939298035 val 0.15540799498558044
epoch 813:
Loss train 0.0016998629653826355 val 0.1537809520959854
epoch 814:
Loss train 0.0016776089705526828 val 0.156813845038414
epoch 815:
Loss train 0.001722911319695413 val 0.15452471375465393
epoch 816:
Loss train 0.0017661166358739137 val 0.154779

sampling loop time step: 100%|██████████| 1000/1000 [00:41<00:00, 24.34it/s]


epoch 902:
Loss train 0.0015547755807638168 val 0.1505279839038849
epoch 903:
Loss train 0.0016641313079744577 val 0.15718239545822144
epoch 904:
Loss train 0.0016751339565962552 val 0.15656980872154236
epoch 905:
Loss train 0.001621507448144257 val 0.14594388008117676
epoch 906:
Loss train 0.001595003355294466 val 0.14666451513767242
epoch 907:
Loss train 0.0016635555867105722 val 0.152962788939476
epoch 908:
Loss train 0.0015704177422448992 val 0.15788963437080383
epoch 909:
Loss train 0.001634200258180499 val 0.15193381905555725
epoch 910:
Loss train 0.0015874839220196009 val 0.15949764847755432
epoch 911:
Loss train 0.001571672074496746 val 0.1528662145137787
epoch 912:
Loss train 0.0016352103427052498 val 0.15462033450603485
epoch 913:
Loss train 0.0016193594196811318 val 0.15323230624198914
epoch 914:
Loss train 0.0015919493576511741 val 0.15612585842609406
epoch 915:
Loss train 0.0015674020079895854 val 0.1542551964521408
epoch 916:
Loss train 0.001662036132067442 val 0.15779165

sampling loop time step: 100%|██████████| 1000/1000 [00:41<00:00, 24.27it/s]


epoch 1002:
Loss train 0.0015119687505066395 val 0.16629712283611298
epoch 1003:
Loss train 0.0015930710565298797 val 0.15521475672721863
epoch 1004:
Loss train 0.0015370525941252708 val 0.16002029180526733
epoch 1005:
Loss train 0.0014674914488568903 val 0.1596078723669052
epoch 1006:
Loss train 0.0015139001673087477 val 0.16359573602676392
epoch 1007:
Loss train 0.0014771541142836212 val 0.16047263145446777
epoch 1008:
Loss train 0.0015784609531983734 val 0.1667172610759735
epoch 1009:
Loss train 0.0015269925305619837 val 0.15235790610313416
epoch 1010:
Loss train 0.0015519520444795488 val 0.1600360870361328
epoch 1011:
Loss train 0.001533260243013501 val 0.154218852519989
epoch 1012:
Loss train 0.0015023225713521242 val 0.1583917886018753
epoch 1013:
Loss train 0.0015292229345068335 val 0.15360954403877258
epoch 1014:


KeyboardInterrupt: 

In [None]:
save_model(best_model_dict, best_model_shadow, best_model_opt, os.path.join(logging_dir, "weights/", "best_model.pth"))

In [None]:
load_checkpoint(model, ema, optim, os.path.join(logging_dir, "weights/", "best_model.pth"))

In [None]:
paths = []
for kde_plot in os.listdir(os.path.join(logging_dir, "kde/")):
    paths.append(os.path.join(logging_dir, "kde/", kde_plot))

make_gif_from_images(paths, os.path.join(logging_dir, "kde/", "kde_progression.gif"))

# Evaluation

In [None]:
real_data_test, real_cond_data_test = next(iter(test_loader))

In [None]:
with torch.no_grad():
    samples = ddpm.sample(batch_size, real_cond_data_test.to(device))
    samples = samples.cpu().numpy()

print(f"Samples shape: {samples.shape}")

In [None]:
vizual_comparison(samples, real_data_test, os.path.join(logging_dir, "viz/", "pca_umap_tsne_all_batches.png"), use_all_data=True);

In [None]:
vizual_comparison(samples, real_data_test, os.path.join(logging_dir, "viz/", "pca_umap_tsne_per_batch.png"));

In [None]:
mmd_histogram_per_customer(samples, real_data_test, fpath=os.path.join(logging_dir, "mmd/", "mmd.png"))

In [None]:
plot_jsd_per_customer(samples, real_data_test, os.path.join(logging_dir, "jsd/", "jsd.png"))

In [None]:
plot_kde_samples(samples, real_data_test, fpath=os.path.join(logging_dir, "kde/", f"kde.png"))

In [None]:
batch_idx = np.random.randint(0, 256) 
customer_indices = np.random.randint(0, 15, size=15)
fig, axs = plt.subplots(2, 1, figsize=(15, 10), sharex=True)

with torch.no_grad():
    sample = ddpm.sample(1, real_cond_data_test.to(device)) 
    sample = sample.cpu().numpy() 

for i in customer_indices:
    axs[0].plot(sample[0, i], alpha=0.7, label=f'Customer {i}')
    axs[1].plot(real_data_test[0, i], alpha=0.7, label=f'Customer {i}')

axs[0].set_title("Generated Data")
axs[1].set_title("Real Data")

for ax in axs:
    ax.set_xlabel('Time')
    ax.set_ylabel('Value')
    ax.grid(True)

plt.tight_layout()
plt.savefig(os.path.join(logging_dir, "ts_sample/", "samples.png"))
plt.show()