In [5]:
import time
from typing import List, Optional, Union



import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader

from gluonts.core.component import validated
import wandb
import os
from tqdm.auto import tqdm
import argparse

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
import math

import torch
from torch import nn
import torch.nn.functional as F

class DiffusionEmbedding(nn.Module):
    def __init__(self, dim, proj_dim, max_steps=256):
        super().__init__()
        self.register_buffer(
            "embedding", self._build_embedding(dim, max_steps), persistent=False
        )
        self.projection1 = nn.Linear(dim * 2, proj_dim)
        self.projection2 = nn.Linear(proj_dim, proj_dim)

    def forward(self, diffusion_step):
        x = self.embedding[diffusion_step]
        x = self.projection1(x)
        x = F.silu(x)
        x = self.projection2(x)
        x = F.silu(x)
        return x

    def _build_embedding(self, dim, max_steps):
        steps = torch.arange(max_steps).unsqueeze(1)  # [T,1]
        dims = torch.arange(dim).unsqueeze(0)  # [1,dim]
        table = steps * 10.0 ** (dims * 4.0 / dim)  # [T,dim]
        table = torch.cat([torch.sin(table), torch.cos(table)], dim=1)
        return table
class ResidualBlock(nn.Module):
    def __init__(self, hidden_size, residual_channels, dilation):
        super().__init__()
        self.dilated_conv = nn.Conv1d(
            residual_channels,
            2 * residual_channels,
            3,
            padding=dilation,
            dilation=dilation,
            padding_mode="zeros",
        )
        self.diffusion_projection = nn.Linear(hidden_size, residual_channels)
        self.output_projection = nn.Conv1d(residual_channels, 2 * residual_channels, 1)

        nn.init.kaiming_normal_(self.output_projection.weight)

    def forward(self, x, diffusion_step):
        diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)

        y = x + diffusion_step
        y = self.dilated_conv(y)

        gate, filter = torch.chunk(y, 2, dim=1)
        y = torch.sigmoid(gate) * torch.tanh(filter)

        y = self.output_projection(y)
        y = F.leaky_relu(y, 0.4)
        residual, skip = torch.chunk(y, 2, dim=1)
        return (x + residual) / math.sqrt(2.0), skip

class EpsilonThetaClass(nn.Module):
    def __init__(
        self,
        num_classes = 5,
        #cond_length,
        time_emb_dim=16,
        residual_layers=8,
        residual_channels=8,
        dilation_cycle_length=2,
        residual_hidden=64,
        class_emb_dim=10,
        target_dim=1,
        
    ):
        super().__init__()
        self.class_embedding = nn.Embedding(num_classes, class_emb_dim)
        
        self.input_projection = nn.Conv1d(
            1+class_emb_dim, residual_channels, 1, padding=2, padding_mode="zeros"
        )
        self.diffusion_embedding = DiffusionEmbedding(
            time_emb_dim, proj_dim=residual_hidden
        )
        self.residual_layers = nn.ModuleList(
            [
                ResidualBlock(
                    residual_channels=residual_channels,
                    dilation=1,
                    hidden_size=residual_hidden,
                )
                for i in range(residual_layers)
            ]
        )
        self.skip_projection = nn.Conv1d(residual_channels, residual_channels, 3)
        self.output_projection = nn.Conv1d(residual_channels, target_dim, 3)

        nn.init.kaiming_normal_(self.input_projection.weight)
        nn.init.kaiming_normal_(self.skip_projection.weight)
        nn.init.kaiming_normal_(self.output_projection.weight)

    def forward(self, inputs, time, class_labels):
        class_embeddings = self.class_embedding(class_labels)  # [batch_size, class_emb_dim]
        class_embeddings = class_embeddings.unsqueeze(2).expand(-1, -1, inputs.size(2))

        # Concatenate class embeddings with inputs
        inputs = torch.cat([inputs, class_embeddings], dim=1)

        x = self.input_projection(inputs)
        x = F.leaky_relu(x, 0.4)

        diffusion_step = self.diffusion_embedding(time)
        skip = []
        for layer in self.residual_layers:
            x, skip_connection = layer(x, diffusion_step)
            skip.append(skip_connection)

        x = torch.sum(torch.stack(skip), dim=0) / math.sqrt(len(self.residual_layers))
        x = self.skip_projection(x)
        x = F.leaky_relu(x, 0.4)
        x = self.output_projection(x)
        return x


In [9]:
from functools import partial
from inspect import isfunction

import numpy as np

import torch
from torch import nn, einsum
import torch.nn.functional as F


#  COND HAS BEEN TAKEN OUT FROM FUNCTIONS


def default(val, d):
    if val is not None:
        return val
    return d() if isfunction(d) else d


def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def noise_like(shape, device, repeat=False):
    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
        shape[0], *((1,) * (len(shape) - 1))
    )
    noise = lambda: torch.randn(shape, device=device)
    return repeat_noise() if repeat else noise()


def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule
    as proposed in https://openreview.model/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    x = np.linspace(0, timesteps, steps)
    alphas_cumprod = np.cos(((x / timesteps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return np.clip(betas, 0, 0.999)


class GaussianDiffusionClass(nn.Module):
    def __init__(
        self,
        denoise_fn,#pass epsilon theta
        input_size,
        beta_end=0.1,
        diff_steps=100,
        loss_type="l2",
        betas=None,
        beta_schedule="linear",
    ):
        super().__init__()
        self.denoise_fn = denoise_fn
        self.input_size = input_size
        self.__scale = None

        if betas is not None:
            betas = (
                betas.detach().cpu().numpy()
                if isinstance(betas, torch.Tensor)
                else betas
            )
        else:
            if beta_schedule == "linear":
                betas = np.linspace(1e-4, beta_end, diff_steps)
            elif beta_schedule == "quad":
                betas = np.linspace(1e-4 ** 0.5, beta_end ** 0.5, diff_steps) ** 2
            elif beta_schedule == "const":
                betas = beta_end * np.ones(diff_steps)
            elif beta_schedule == "jsd":  # 1/T, 1/(T-1), 1/(T-2), ..., 1
                betas = 1.0 / np.linspace(diff_steps, 1, diff_steps)
            elif beta_schedule == "sigmoid":
                betas = np.linspace(-6, 6, diff_steps)
                betas = (beta_end - 1e-4) / (np.exp(-betas) + 1) + 1e-4
            elif beta_schedule == "cosine":
                betas = cosine_beta_schedule(diff_steps)
            else:
                raise NotImplementedError(beta_schedule)

        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.loss_type = loss_type

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer("betas", to_torch(betas))
        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
        self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer(
            "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
        )
        self.register_buffer(
            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
        )
        self.register_buffer(
            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
        )
        self.register_buffer(
            "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
        )

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer("posterior_variance", to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer(
            "posterior_log_variance_clipped",
            to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
        )
        self.register_buffer(
            "posterior_mean_coef1",
            to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
        )
        self.register_buffer(
            "posterior_mean_coef2",
            to_torch(
                (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
            ),
        )

    @property
    def scale(self):
        return self.__scale

    @scale.setter
    def scale(self, scale):
        self.__scale = scale

    def q_mean_variance(self, x_start, t):
        mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = extract(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def predict_start_from_noise(self, x_t, t, noise):
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, x, t, class_labels, clip_denoised: bool):
        x_recon = self.predict_start_from_noise(
            x, t=t, noise=self.denoise_fn(x, t, class_labels)
        )

        if clip_denoised:
            x_recon.clamp_(-1.0, 1.0)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
            x_start=x_recon, x_t=x, t=t
        )
        return model_mean, posterior_variance, posterior_log_variance
    
    @torch.no_grad()
    def q_sample_loop(self,x_0, shape):
        device = self.betas.device

        b=shape[0]
        img=torch.empty(self.num_timesteps, *shape)
        for i in range(0, self.num_timesteps) :
            img[i]=self.q_sample(x_0, torch.full((b,), i, device=device, dtype=torch.long))
        return img

    @torch.no_grad()
    def p_sample(self, x, t, class_labels, clip_denoised=False, repeat_noise=False):
        b, *_, device = *x.shape, x.device
        model_mean, _, model_log_variance = self.p_mean_variance(
            x=x, t=t, class_labels = class_labels, clip_denoised=clip_denoised
        )
        noise = noise_like(x.shape, device, repeat_noise)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    @torch.no_grad()
    def p_sample_loop(self, x, class_labels):
        device = self.betas.device

        b = x.shape[0]
        img = torch.randn(x.shape, device=device)

        for i in reversed(range(0, self.num_timesteps)):
            img = self.p_sample(
                img, torch.full((b,),class_labels, i, device=device, dtype=torch.long)
            )
        return img
    

    @torch.no_grad()
    def sample(self, class_labels, sample_shape=torch.Size(), cond=None):
        if cond is not None:
            shape = cond.shape[:-1] + (self.input_size,)
            # TODO reshape cond to (B*T, 1, -1)
        else:
            shape = sample_shape
        x_hat = self.p_sample_loop(shape, class_labels, cond)  # TODO reshape x_hat to (B,T,-1)

        if self.scale is not None:
            x_hat *= self.scale
        return x_hat

    @torch.no_grad()
    def interpolate(self, x1, x2,class_labels, t=None, lam=0.5):
        b, *_, device = *x1.shape, x1.device
        t = default(t, self.num_timesteps - 1)

        assert x1.shape == x2.shape

        t_batched = torch.stack([torch.tensor(t, device=device)] * b)
        xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))

        img = (1 - lam) * xt1 + lam * xt2
        for i in reversed(range(0, t)):
            img = self.p_sample(
                img, torch.full((b,), class_labels,i, device=device, dtype=torch.long)
            )

        return img

    def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def p_losses(self, x_start, t, class_labels, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))

        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        x_recon = self.denoise_fn(x_noisy, t, class_labels)

        if self.loss_type == "l1":
            loss = F.l1_loss(x_recon, noise)
        elif self.loss_type == "l2":
            loss = F.mse_loss(x_recon, noise)
        elif self.loss_type == "huber":
            loss = F.smooth_l1_loss(x_recon, noise)
        else:
            raise NotImplementedError()

        return loss

    def log_prob(self, x, class_labels,*args, **kwargs):
        if self.scale is not None:
            x /= self.scale

        B, T, _ = x.shape

        time = torch.randint(0, self.num_timesteps, (B * T,), device=x.device).long()
        loss = self.p_losses(
            x.reshape(B * T, 1, -1),  time, class_labels,*args, **kwargs
        )

        return loss
    
    
        


In [11]:
class Trainer:

    def __init__(
            self,
            net: GaussianDiffusionClass,
            epochs: int = 50,
            batch_size: int = 32,
            num_batches_per_epoch: int = 50,
            learning_rate: float = 1e-3,
            weight_decay: float = 1e-6,
            maximum_learning_rate: float = 1e-2,
            model_name : str = 'model',
            model_type : str ='torch',
            model_save_path : str = 'model_sav_path',
            input_size = [256],
           

            **kwargs,
    )->None:
        
        self.epochs = epochs
        self.batch_size = batch_size
        self.num_batches_per_epoch = num_batches_per_epoch
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.maximum_learning_rate = maximum_learning_rate
        self.model_name = model_name
        self.model_type = model_type
        self.model_save_path = model_save_path
        self.input_size = input_size
        self.net = net
        


    
    def __call__(
            self,
           
            train_iter: DataLoader,
 
    )->None:
        
        wandb.login()
        
        wandb.init(project="test_train_class")

        # Log hyperparameters and other configurations
        config = {
            'epochs': self.epochs,
            'batch_size': self.batch_size,
            'num_batches_per_epoch': self.num_batches_per_epoch,
            'learning_rate': self.learning_rate,
            'weight_decay': self.weight_decay,
            'maximum_learning_rate': self.maximum_learning_rate,
        }
        wandb.config.update(config)

        optimizer = Adam(
            self.net.parameters(), lr = self.learning_rate, weight_decay = self.weight_decay
        )

        lr_scheduler = OneCycleLR(
            optimizer,
            max_lr = self.maximum_learning_rate,
            steps_per_epoch = self.num_batches_per_epoch,
            epochs = self.epochs,
        ) 

        losses_t = []
        for epoch in range(self.epochs):
            tic = time.time()
            cumm_epoch_loss = 0.0

            with tqdm(train_iter, total=self.num_batches_per_epoch - 1) as it:
                for batch_no, data_entry in enumerate(it, start=1):
                    optimizer.zero_grad()
                    signals = data_entry['signals']
                    class_labels = data_entry['sc']  # Ensure this is included in your model's log_prob method
                    losses = self.net.log_prob(signals, class_labels=class_labels)
                    cumm_epoch_loss += losses.item()

                    avg_epoch_loss = cumm_epoch_loss / batch_no
                    it.set_postfix({"epoch": f"{epoch + 1}/{self.epochs}", "avg_loss": avg_epoch_loss}, refresh=False)

                    wandb.log({"train_loss": losses.item()})
                    losses.backward()
                    optimizer.step()
                
                

                    if self.num_batches_per_epoch == batch_no:
                        break
            wandb.log({
                "epoch": epoch + 1,
                "train_loss": avg_epoch_loss,
                "learning_rate": optimizer.param_groups[0]['lr'],
                "gradient_norm": self.calculate_gradient_norm(self.net),
                "training_time_per_epoch": time.time() - tic,
            })
            losses_t.append(avg_epoch_loss)
        
        self.save_model_as_artifact(self.net)
            

    @staticmethod
    def calculate_gradient_norm(net):
        total_norm = 0.0
        for param in net.parameters():
            if param.grad is not None:
                param_norm = param.grad.data.norm(2)
                total_norm += param_norm.item() ** 2
        total_norm = total_norm ** 0.5
        return total_norm
    
    def save_model_as_artifact(self, model):
        model.eval()  # Set the model to evaluation mode

        # Save the model state dictionary instead of using ONNX
        
        if not os.path.exists(self.model_save_path):
            os.makedirs(self.model_save_path)
        model_path = os.path.join(self.model_save_path, f'{self.model_name}.pth')
        torch.save(model.state_dict(), model_path)

        # Create an artifact for logging to wandb
        artifact = wandb.Artifact(self.model_name, type=self.model_type)
        
        # Add metadata to the artifact
        artifact.metadata = {
            'format': 'pytorch_state_dict',
            'model_type': self.model_type,
            'epochs': self.epochs,
            'batch_size': self.batch_size,
            'learning_rate': self.learning_rate,
            'weight_decay': self.weight_decay,
            'layers': [str(layer) for layer in model.children()],
        }
        
        # Add the model file to the artifact
        artifact.add_file(model_path)

        # Log the artifact to wandb
        wandb.log_artifact(artifact)


def custom_collate_fn(batch):
        """
        Custom collate function to reshape data into [batch size, channels, size].
        """
        # Assuming your signals are originally in the shape [size]
        # and you want to add a single channel dimension
        signals = torch.stack([item['signals'] for item in batch]).unsqueeze(1)  # Adds a channel dimension
        gt = torch.stack([item['gt'] for item in batch])
        sc = torch.stack([item['sc'] for item in batch])
        
        return {'signals': signals, 'gt': gt, 'sc': sc}

In [18]:
file_path = 'C:/Users/Alexia/datasets/train_set_5classes.pth'
dataset = torch.load(file_path)
epochs = 50
train_loader = DataLoader(dataset, batch_size=128, shuffle=True, collate_fn=custom_collate_fn)
net = GaussianDiffusionClass(EpsilonThetaClass(), input_size = 256)

model_name = f'CLASS_batch{128}_lr{0.001}_e{epochs}'

# Initialize and configure wandb run
wandb.init(project="test_train_class", name=model_name, reinit=True)

trainer = Trainer(
    net = net,
    batch_size=128,
    learning_rate=0.001,
    model_name=model_name,
)
trainer(train_loader)

# Ensure the current wandb run is properly closed before the next
wandb.finish()



  0%|          | 0/49 [00:00<?, ?it/s]


IndexError: index out of range in self