In [23]:
import torch
import pandas as pd
import numpy as np
import torch.nn as nn
from modules import UNet_conditional
from diffusion import *
from utils import *
from torch.amp import autocast

In [24]:
# Checkpoint to evaluate (presumably EMA weights, judging by the file name).
path = "models/test/ema_ckpt.pt"
print("Loading ", path)

# Prefer the GPU when one is visible; everything below is placed on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
print()

# Rebuild the network, then restore its weights. length/feat_num must match the
# values the checkpoint was trained with, otherwise load_state_dict will reject it.
model = UNet_conditional(length=1024, feat_num=3, device=device)
model = model.to(device)
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
# For a plain state_dict, weights_only=True is available on recent PyTorch.
ckpt = torch.load(path, map_location=device)
model.load_state_dict(ckpt)

# Sampler: 1000 noise steps respaced to 40 sampling steps (section_counts=[40]).
sampler = SpacedDiffusion(
    beta_start=1e-4,
    beta_end=0.02,
    noise_steps=1000,
    section_counts=[40],
    length=1024,
    device=device,
    rescale_timesteps=False,
)
    

Loading  models/test/ema_ckpt.pt
Using device: cpu



In [25]:
def predict(model, sampler, test_dl, device, n_samples=4):
    """
    Run the DDIM sampler once per batch of ``test_dl`` and collect the results.

    Each batch is expected to be a mapping holding a ``'data'`` tensor (the
    reference vectors) and a ``'settings'`` tensor (the conditioning that is
    passed to the sampler as ``y``). Sampling uses ``cfg_scale=1`` and
    ``eta=1``. It runs under ``torch.no_grad()`` after ``model.eval()``.

    Parameters
    ----------
    model : torch.nn.Module
        Network handed to ``sampler.ddim_sample_loop``.
    sampler : object
        Anything exposing ``ddim_sample_loop(model, y, cfg_scale, device, eta, n)``.
    test_dl : iterable
        Batches as described above, typically a ``DataLoader``.
    device : torch.device or str
        Device the batch tensors are moved to; also forwarded to the sampler.
    n_samples : int, optional
        Forwarded to ``ddim_sample_loop`` as ``n`` for every batch.

    Returns
    -------
    tuple[list, list]
        ``(x_real, predictions)``. Both are flat across batches and converted
        to nested Python lists via ``.tolist()``.

    NOTE(review): ``n_samples`` is the same for every batch, but a batch holds
    ``len(settings)`` rows, and a short final batch holds fewer. If the sampler
    draws ``n`` samples regardless of ``y``, the two returned lists will not
    line up one-to-one. Confirm against ``SpacedDiffusion.ddim_sample_loop``.
    """
    ground_truth, sampled = [], []

    model.eval()

    with torch.no_grad():
        for batch in tqdm(test_dl, desc="Testing loop"):
            targets = batch['data'].to(device)
            conditioning = batch['settings'].to(device)

            generated = sampler.ddim_sample_loop(
                model=model,
                y=conditioning,
                cfg_scale=1,
                device=device,
                eta=1,
                n=n_samples,
            )

            # Bring tensors back to host memory first; .tolist() needs CPU tensors.
            ground_truth.extend(targets.cpu().tolist())
            sampled.extend(generated.cpu().tolist())

    return ground_truth, sampled
    

In [26]:
def evaluate(model, sampler, device, test_csv_path, n_samples=4, batch_size=4):
    """
    Score model predictions on a test CSV with a per-item mean squared error.

    Loads the test set with ``get_data`` and wraps it in ``CustomDataset`` and an
    unshuffled ``DataLoader``. It then samples predictions with ``predict`` and
    compares each prediction with the real vector at the same position.

    Parameters
    ----------
    model : torch.nn.Module
        Network forwarded to ``predict``.
    sampler : object
        Diffusion sampler forwarded to ``predict``.
    device : torch.device or str
        Device used for sampling.
    test_csv_path : str
        Path of the test CSV read by ``get_data``.
    n_samples : int, optional
        Forwarded unchanged to ``predict``.
        NOTE(review): ``predict`` passes this to the sampler as the per-batch
        sample count, so it presumably has to equal ``batch_size``. A short
        final batch may still break the one-to-one pairing; confirm against
        the sampler.
    batch_size : int, optional
        Batch size of the test ``DataLoader``.

    Returns
    -------
    list[float]
        One MSE value per test item, in dataset order.

    Raises
    ------
    ValueError
        If ``predict`` returns a different number of predictions than real
        vectors, so the two cannot be paired item by item.
    """
    x_test, y_test = get_data(test_csv_path)

    test_dataset = CustomDataset(x_test, y_test)
    # shuffle=False keeps predictions aligned with dataset order.
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    x_real, predictions = predict(model, sampler, test_dataloader, device=device, n_samples=n_samples)

    # Fail loudly instead of raising IndexError or silently mis-pairing items.
    if len(predictions) != len(x_real):
        raise ValueError(
            f"predict returned {len(predictions)} predictions for {len(x_real)} real vectors; "
            "check that n_samples matches the number of items per batch"
        )

    mse = nn.MSELoss()
    mse_errors = []
    # predict() returns nested Python lists (.tolist()), but nn.MSELoss only
    # accepts tensors, so convert each pair back before computing the loss.
    for pred, real in zip(predictions, x_real):
        err = mse(torch.as_tensor(pred, dtype=torch.float32),
                  torch.as_tensor(real, dtype=torch.float32))
        mse_errors.append(err.item())

    return mse_errors
    
# Full test-set evaluation. In the recorded CPU run each batch took ~13 s (40 DDIM steps at ~3 it/s).
mse_errors = evaluate(model=model, sampler=sampler, device=device, test_csv_path="../data/test_data.csv")

Testing loop:   0%|          | 0/392 [00:00<?, ?it/s]
ddim sample loop:   0%|          | 0/40 [00:00<?, ?it/s][A
ddim sample loop:   2%|▎         | 1/40 [00:00<00:13,  2.94it/s][A
ddim sample loop:   5%|▌         | 2/40 [00:00<00:14,  2.67it/s][A
ddim sample loop:   8%|▊         | 3/40 [00:01<00:14,  2.62it/s][A
ddim sample loop:  10%|█         | 4/40 [00:01<00:12,  2.82it/s][A
ddim sample loop:  12%|█▎        | 5/40 [00:01<00:11,  2.95it/s][A
ddim sample loop:  15%|█▌        | 6/40 [00:02<00:11,  3.03it/s][A
ddim sample loop:  18%|█▊        | 7/40 [00:02<00:10,  3.10it/s][A
ddim sample loop:  20%|██        | 8/40 [00:02<00:10,  3.16it/s][A
ddim sample loop:  22%|██▎       | 9/40 [00:02<00:09,  3.16it/s][A
ddim sample loop:  25%|██▌       | 10/40 [00:03<00:09,  3.12it/s][A
ddim sample loop:  28%|██▊       | 11/40 [00:03<00:09,  3.16it/s][A
ddim sample loop:  30%|███       | 12/40 [00:03<00:08,  3.20it/s][A
ddim sample loop:  32%|███▎      | 13/40 [00:04<00:08,  3.24it/s][

KeyboardInterrupt: 