In [None]:
from models.model import ImagenTime
from utils.utils_data import gen_dataloader
from utils.utils import restore_state
import sys
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from models.sampler import DiffusionProcess
%matplotlib inline
import torch

# Define args as needed for gen_dataloader
class Args:
    """
    Hand-rolled configuration object consumed by `gen_dataloader` and `ImagenTime`.

    NOTE(review): `seq_len` (read by ImagenTime when use_stft is True) is not set
    here; presumably `gen_dataloader` fills it in on `args` — confirm.
    """

    def __init__(self):
        # NOTE(review): 69696 == 264 * 264, which the commented-out plotting code in the
        # training cell treats as a feature count — confirm this batch size is intended.
        self.batch_size = 69696  # check this
        self.shuffle = True
        self.num_workers = 4
        self.dataset = "fmri"
        self.device = "cuda"

        # time series -> image embedding (STFT)
        self.use_stft = True
        self.diffusion_steps = 18
        self.n_fft = 63
        self.hop_length = 23
        self.img_resolution = 32

        # denoising network
        self.input_channels = 2
        self.unet_channels = 128
        self.ch_mult = [1, 2, 4, 4]
        self.attn_resolution = [32, 16, 8]

        # EMA / logging
        self.ema = True
        self.ema_warmup = 100
        self.logging_iter = 100

        # optimizer; use assignment (`=`), not annotation (`:`) — a bare annotation
        # such as `self.learning_rate: 0.0003` never sets the attribute.
        self.learning_rate = 0.0003  # 1e-4
        self.weight_decay = 0.00001  # 1e-5


# Instantiate the configuration and build the train/test loaders from it.
args = Args()
(train_loader, test_loader) = gen_dataloader(args)
print("dataset ready")

./data/short_range/padded_fmri_set.pt
dataset ready


In [7]:
from abc import ABC, abstractmethod
import torch
import torchaudio.transforms as T
from utils.utils_data import MinMaxScaler, MinMaxArgs
import torch
import torch.nn as nn
from contextlib import contextmanager
from models.networks import EDMPrecond
from models.ema import LitEma





class TsImgEmbedder(ABC):
    """Interface for converters between time series and image representations.

    Concrete subclasses implement both directions: ``ts_to_img`` (series -> image)
    and ``img_to_ts`` (image -> series).
    """

    def __init__(self, device, seq_len):
        # Stored for subclasses; the base class itself does not use them.
        self.device, self.seq_len = device, seq_len

    @abstractmethod
    def ts_to_img(self, signal):
        """Convert a time series to its image representation.

        Args:
            signal: the input time series.

        Returns:
            The image representation of ``signal``.
        """

    @abstractmethod
    def img_to_ts(self, img):
        """Convert an image (e.g. a generated sample) back to a time series.

        Args:
            img: the image to convert.

        Returns:
            The time-series representation of ``img``.
        """




class STFTEmbedder(TsImgEmbedder):
    """
    Time series <-> image embedding via the short-time Fourier transform (STFT).

    The complex spectrogram of each channel is split into real and imaginary
    parts, each min-max scaled to [-1, 1] with statistics cached from the
    training data, and the two parts are concatenated along dim 1.
    `cache_min_max_params` must be called before `ts_to_img` / `img_to_ts`.
    """

    def __init__(self, device, seq_len, n_fft, hop_length):
        super().__init__(device, seq_len)
        self.n_fft = n_fft
        self.hop_length = hop_length
        # Scaling statistics, filled in by cache_min_max_params().
        self.min_real, self.max_real, self.min_imag, self.max_imag = None, None, None, None

    def _check_scaling_cached(self):
        """Fail fast if cache_min_max_params() has not been called yet."""
        assert self.min_real is not None, "use cache_min_max_params() to compute scaling arguments"

    def cache_min_max_params(self, train_data):
        """
        Compute and cache min/max values of the real and imaginary STFT parts.

        Call this once, on the full training set, before the training loop starts.

        Args:
            train_data: training time-series tensor. Shape: B*L*K
        """
        real, imag = self.stft_transform(train_data)
        # Only the min/max statistics are kept; the first value MinMaxScaler returns is unused.
        # .cpu() is a no-op for CPU tensors and lets this work when train_data is on a GPU.
        _, min_real, max_real = MinMaxScaler(real.cpu().numpy(), True)
        _, min_imag, max_imag = MinMaxScaler(imag.cpu().numpy(), True)
        self.min_real, self.max_real = torch.Tensor(min_real), torch.Tensor(max_real)
        self.min_imag, self.max_imag = torch.Tensor(min_imag), torch.Tensor(max_imag)

    def stft_transform(self, data):
        """
        Args:
            data: time-series tensor. Shape: B*L*K
        Returns:
            (real, imag) parts of the complex STFT, computed per channel over the time axis
        """
        # Spectrogram transforms the last dimension, so move time (L) to the end: B*K*L.
        data = torch.permute(data, (0, 2, 1))
        spec = T.Spectrogram(n_fft=self.n_fft, hop_length=self.hop_length, center=True, power=None).to(data.device)
        transformed_data = spec(data)  # complex-valued because power=None
        return transformed_data.real, transformed_data.imag

    def ts_to_img(self, signal):
        """
        Args:
            signal: time-series tensor. Shape: B*L*K
        Returns:
            real and imaginary STFT parts, scaled to [-1, 1], concatenated along dim 1
        """
        self._check_scaling_cached()
        real, imag = self.stft_transform(signal)
        # Min-max scale with the cached statistics; (v - 0.5) * 2 then maps [0, 1] -> [-1, 1]
        # (assumes MinMaxArgs scales to [0, 1], which img_to_ts's inverse also assumes).
        real = (MinMaxArgs(real, self.min_real.to(self.device), self.max_real.to(self.device)) - 0.5) * 2
        imag = (MinMaxArgs(imag, self.min_imag.to(self.device), self.max_imag.to(self.device)) - 0.5) * 2
        return torch.cat((real, imag), dim=1)

    def img_to_ts(self, x_image):
        """
        Invert ts_to_img: undo the [-1, 1] scaling and apply the inverse STFT.

        Args:
            x_image: image tensor whose dim 1 holds the real parts followed by the imaginary parts
        Returns:
            reconstructed time series of length self.seq_len. Shape: B*L*K
        """
        self._check_scaling_cached()
        min_real, max_real = self.min_real.to(self.device), self.max_real.to(self.device)
        min_imag, max_imag = self.min_imag.to(self.device), self.max_imag.to(self.device)

        # dim 1 is twice the original channel count: first half real, second half imaginary.
        split = torch.split(x_image, x_image.shape[1] // 2, dim=1)
        real, imag = split[0], split[1]

        # Undo (v - 0.5) * 2 and then the min-max scaling.
        unnormalized_real = ((real / 2) + 0.5) * (max_real - min_real) + min_real
        unnormalized_imag = ((imag / 2) + 0.5) * (max_imag - min_imag) + min_imag
        unnormalized_stft = torch.complex(unnormalized_real, unnormalized_imag)

        ispec = T.InverseSpectrogram(n_fft=self.n_fft, hop_length=self.hop_length, center=True).to(self.device)
        x_time_series = ispec(unnormalized_stft, self.seq_len)  # B*K*L
        return torch.permute(x_time_series, (0, 2, 1))  # B*L*K(C)


In [19]:
class ImagenTime(nn.Module):
    """
    EDM-style diffusion model over image embeddings of time series.

    The denoiser is an `EDMPrecond` network; time series are mapped to images
    (and back) by `self.ts_img`. Optionally keeps an EMA copy of the network
    weights.

    NOTE(review): this cell shadows the `ImagenTime` imported from
    `models.model` in the first cell; code run after this cell uses this version.
    """

    def __init__(self, args, device):
        """
        Args:
            args: configuration object. Reads diffusion_steps, img_resolution,
                input_channels, ch_mult, unet_channels, attn_resolution,
                use_stft and ema (plus ema_warmup when ema is set). With
                use_stft it also reads seq_len, n_fft and hop_length; otherwise
                delay, embedding and seq_len.
            device: device handed to the time-series <-> image embedder.
        """
        super().__init__()
        # EDM noise-level distribution and preconditioning constants.
        self.P_mean = -1.2
        self.P_std = 1.2
        self.sigma_data = 0.5
        self.sigma_min = 0.002
        self.sigma_max = 80
        self.rho = 7
        self.T = args.diffusion_steps  # number of diffusion steps

        self.device = device
        self.net = EDMPrecond(args.img_resolution, args.input_channels, channel_mult=args.ch_mult,
                              model_channels=args.unet_channels, attn_resolutions=args.attn_resolution)

        if args.use_stft:
            self.ts_img = STFTEmbedder(self.device, args.seq_len, args.n_fft, args.hop_length)
        else:
            # Delay embedding was requested, but DelayEmbedder is not wired up in this notebook.
            self.delay = args.delay
            self.embedding = args.embedding
            self.seq_len = args.seq_len
            # TODO: self.ts_img = DelayEmbedder(self.device, args.seq_len, args.delay, args.embedding)
            # Set explicitly so later calls fail with a clear error instead of a missing-attribute error.
            self.ts_img = None

        self.use_ema = bool(args.ema)
        if self.use_ema:
            # `use_num_upates` is presumably LitEma's own keyword spelling — check models.ema before changing.
            self.model_ema = LitEma(self.net, decay=0.9999, use_num_upates=True, warmup=args.ema_warmup)

    def _require_embedder(self):
        """Return the time-series <-> image embedder, raising if none is configured."""
        embedder = getattr(self, "ts_img", None)
        if embedder is None:
            raise NotImplementedError(
                "no time-series/image embedder configured: delay embedding (use_stft=False) "
                "is not available in this notebook"
            )
        return embedder

    def ts_to_img(self, signal, pad_val=None):
        """
        Convert a time series to its image representation.

        Args:
            signal: signal to convert to image
            pad_val: padding value, only meaningful for delay embedding (e.g. 1 when
                building a mask). Do not use with the STFT embedder. A falsy value
                (None or 0) calls the embedder with its own defaults.

        Raises:
            NotImplementedError: if no embedder is configured.
        """
        embedder = self._require_embedder()
        if pad_val:
            return embedder.ts_to_img(signal, True, pad_val)
        return embedder.ts_to_img(signal)

    def img_to_ts(self, img):
        """Convert an image back to a time series via the configured embedder."""
        return self._require_embedder().img_to_ts(img)

    def init_stft_embedder(self, train_loader):
        """
        Cache min/max values of the real and imaginary STFT parts over the whole
        training set (used for normalization). Must be called before training.

        Args:
            train_loader: iterable of batches; element 0 of each batch is the time series

        Raises:
            AssertionError: if the model does not use the STFT embedder.
            ValueError: if train_loader yields no batches.
        """
        assert isinstance(getattr(self, "ts_img", None), STFTEmbedder), \
            "You must use the STFTEmbedder to initialize the min and max values"
        series = [batch[0] for batch in train_loader]
        if not series:
            raise ValueError("train_loader yielded no batches; cannot compute STFT scaling statistics")
        self.ts_img.cache_min_max_params(torch.cat(series, dim=0))

    def _sample_sigma_and_weight(self, reference):
        """Draw one log-normal noise level per example of `reference`, plus its EDM loss weight."""
        rnd_normal = torch.randn([reference.shape[0], 1, 1, 1], device=reference.device)
        # ln(sigma) ~ N(P_mean, P_std^2)
        sigma = (rnd_normal * self.P_std + self.P_mean).exp()
        weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
        return sigma, weight

    def loss_fn(self, x):
        """
        EDM denoising loss on a batch of clean images.

        Args:
            x: clean image batch (e.g. the output of ts_to_img)

        Returns:
            (loss, to_log): the weighted squared error between the denoiser output
            and x, averaged, and a dict holding its scalar value under 'karras loss'.
        """
        output, weight = self.forward(x)
        loss = (weight * (output - x).square()).mean()
        return loss, {'karras loss': loss.detach().item()}

    def loss_fn_impute(self, x, mask):
        """
        EDM denoising loss restricted to the entries being imputed.

        Args:
            x: clean image batch
            mask: multiplies x; entries where (1 - mask) is non-zero are noised and
                scored (presumably a 0/1 mask — confirm with callers)

        Returns:
            (loss, to_log), as in loss_fn.
        """
        output, weight = self.forward_impute(x, mask)
        # NOTE(review): when mask has the same shape as x, these unpad calls crop
        # each tensor to its own shape and therefore do nothing.
        x = self.unpad(x * (1 - mask), x.shape)
        output = self.unpad(output * (1 - mask), x.shape)
        loss = (weight * (output - x).square()).mean()
        return loss, {'karras loss': loss.detach().item()}

    def forward(self, x, labels=None, augment_pipe=None):
        """
        Noise x at a randomly drawn per-example level and denoise it.

        Returns:
            (D_yn, weight): denoiser output and per-example EDM loss weight.
        """
        sigma, weight = self._sample_sigma_and_weight(x)
        y, augment_labels = augment_pipe(x) if augment_pipe is not None else (x, None)
        n = torch.randn_like(y) * sigma
        D_yn = self.net(y + n, sigma, labels, augment_labels=augment_labels)
        return D_yn, weight

    def forward_impute(self, x, mask, labels=None, augment_pipe=None):
        """
        Like forward, but only the entries selected by (1 - mask) are noised; the
        entries selected by mask are passed to the network clean.

        Returns:
            (D_yn, weight): denoiser output and per-example EDM loss weight.
        """
        sigma, weight = self._sample_sigma_and_weight(x)

        # noisy part to impute
        n = torch.randn_like(x) * sigma
        noise_impute = n * (1 - mask)
        x_to_impute = x * (1 - mask) + noise_impute

        # clean, observed part
        x = x * mask
        y, augment_labels = augment_pipe(x) if augment_pipe is not None else (x, None)

        D_yn = self.net(y + x_to_impute, sigma, labels, augment_labels=augment_labels)
        return D_yn, weight

    def forward_forecast(self, past, future, labels=None, augment_pipe=None):
        """
        Denoise a noised `future` conditioned on a clean `past`, both concatenated
        along the last dimension and right-padded by pad_f.

        Returns:
            (D_yn, weight): denoiser output restricted to the future positions, and
            the per-example EDM loss weight.
        """
        s, e = past.shape[-1], future.shape[-1]
        sigma, weight = self._sample_sigma_and_weight(past)
        y, augment_labels = augment_pipe(past) if augment_pipe is not None else (past, None)
        n = torch.randn_like(future) * sigma
        full_seq = self.pad_f(torch.cat([past, future + n], dim=-1))
        D_yn = self.net(full_seq, sigma, labels, augment_labels=augment_labels)[..., s:(s + e)]
        return D_yn, weight

    def pad_f(self, x):
        """
        Right-pad the last dimension of a 4-D tensor with zeros to at least 32 entries.

        Only the last dimension is padded; the second-to-last is left unchanged, and
        inputs already 32 or more wide are returned unpadded.
        NOTE(review): 32 presumably mirrors args.img_resolution — consider parameterizing.
        """
        _, _, _, width = x.shape
        target = max(32, width)
        # F.pad order: (last_left, last_right, second_last_top, second_last_bottom)
        return torch.nn.functional.pad(x, (0, target - width, 0, 0), mode='constant', value=0)

    def unpad(self, x, original_shape):
        """Crop the last two dimensions of x back to those of `original_shape`."""
        _, _, original_cols, original_rows = original_shape
        return x[:, :, :original_cols, :original_rows]

    @contextmanager
    def ema_scope(self, context=None):
        """
        Context manager that temporarily swaps the EMA weights into self.net and
        restores the stored training weights on exit (also on error).

        Args:
            context: optional label printed when switching weights
        """
        if self.use_ema:
            self.model_ema.store(self.net.parameters())
            self.model_ema.copy_to(self.net)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.net.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def on_train_batch_end(self, *args):
        """Update the EMA copy of the network weights, if EMA is enabled."""
        if self.use_ema:
            self.model_ema(self.net)


In [20]:
# Choose the compute device (CUDA when available, otherwise CPU) and record it on args.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args.device = device

# Build the model; the STFT scaling statistics must be cached from the training set before training.
model = ImagenTime(args=args, device=args.device).to(args.device)
if args.use_stft:
    model.init_stft_embedder(train_loader)
    print("STFT embedder initialized")

# Optimisation and run-control settings, attached to args after the model is built.
vars(args).update(
    learning_rate=0.0003,   # 3e-4
    weight_decay=0.00001,   # 1e-5
    resume=False,
    epochs=1000,
    beta1=1e-05,
    betaT=0.01,
    deterministic=False,
)

optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
state = {"model": model, "epoch": 0}
init_epoch = 0

# Optionally resume from a checkpoint, passing the EMA model along when EMA is enabled.
if args.resume:
    ema_model = model.model_ema if args.ema else None
    init_epoch = restore_state(args, state, ema_model=ema_model)


STFT embedder initialized


In [21]:
# Training-loop cell.
# NOTE(review): in its current state this is a smoke test, not a training run:
#   - the optimizer/loss step and the evaluation loop are commented out,
#   - the inner `break` stops after the first batch,
#   - the outer `break` stops after the first epoch.
# The only work performed is converting one batch of time series to STFT images.
for epoch in range(init_epoch, args.epochs):
    model.train()
    # NOTE(review): model.epoch is not read anywhere in the visible code.
    model.epoch = epoch

    # --- train loop ---
    for i, data in enumerate(train_loader, 1):
        x_ts = data[0].to(args.device)  # x_ts contains the time series batch

  

        # time series -> image (real/imag STFT parts concatenated along dim 1)
        x_img = model.ts_to_img(x_ts)



        # NOTE(review): the commented-out plotting blocks below call `.numpy()` on x_ts,
        # which was moved to args.device above; on CUDA this needs `.cpu()` first.

        # # -HeatMap Features Vs Time Steps - p0
        # data_squeezed = x_ts.squeeze(-1)  # Remove the singleton dimension
        # plt.figure(figsize=(12, 6))
        # sns.heatmap(data_squeezed.numpy(), cmap="viridis", cbar=True)

        # # Add labels and title
        # plt.xlabel("Feature Index")
        # plt.ylabel("Time Steps")
        # plt.title("Features vs. Time Steps (Heatmap)")
        # plt.savefig('Ts_heatmap_fmri_pre.png')
        # plt.show()



        # - Take mean of last dimension method
        # NOTE(review): the view below hard-codes x_ts as 17 * 264 * 264 elements — verify the batch shape.
        # Step 1: Reshape the data to (17, 264, 264)
        # data_reshaped = x_ts.view(17, 264, 264)  # Reshaping (69696 -> 264 x 264)

        # # Step 2: Compute the mean along the last dimension (mean of each 264 group)
        # mean_features = data_reshaped.mean(dim=2)  # Shape: (17, 264)

        # # Step 3: Plot the 264 lines
        # plt.figure(figsize=(12, 8))
        # for feature_idx in range(264):
        #     plt.plot(range(17), mean_features[:, feature_idx].numpy(), label=f"Group {feature_idx + 1}")

        # # Add labels, title, and grid
        # plt.xlabel("Time Steps")
        # plt.ylabel("Mean Feature Value")
        # plt.title("Mean of 264 Groups of Features Across Time Steps")
        # plt.grid(True)

        # Optional: Add a legend (only if you want to show it for some groups)
        # plt.legend(loc="upper right", fontsize="small", ncol=2)




        # -Density plot

        # # Select a specific time step
        # time_step_idx = 0
        # features = x_ts[time_step_idx, :, 0].numpy()  # Extract all features at the selected time step
        # time_steps_to_compare = [0, 8, 16]

        # plt.figure(figsize=(12, 6))
        # for t in time_steps_to_compare:
        #     features = x_ts[t, :, 0].numpy()
        #     sns.kdeplot(features, label=f"Time Step {t}", alpha=0.6)

        # plt.xlabel("Feature Values")
        # plt.ylabel("Density")
        # plt.title("Feature Value Distributions Across Time Steps")
        # plt.legend()
        # plt.grid(True)
        # plt.show()



        # ts -> img done
        
        # img_sample = x_img[0, 0, :, :]  # Shape: (height, width)
        # # Convert to NumPy for visualization
        # img_numpy = img_sample.cpu().detach().numpy()
        # # Plot the spectrogram
        # plt.figure(figsize=(10, 5))
        # plt.imshow(img_numpy, aspect='auto', origin='lower', cmap='viridis')
        # plt.colorbar(label='Intensity')
        # plt.title("Spectrogram (Time vs Frequency)")
        # plt.xlabel("Time Steps")
        # plt.ylabel("Frequency Bins")
        # plt.savefig('spectrogram_fmri.png')
        # plt.show()

        
    # break


        # Disabled training step (loss, backward, grad clipping, optimizer step, EMA update).
        # optimizer.zero_grad()
        # loss = model.loss_fn(x_img)
        # if len(loss) == 2:
        #     loss, to_log = loss
        #     for key, value in to_log.items():
        #         print(f'train/{key}', value, epoch)

        # loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
        # optimizer.step()
        # model.on_train_batch_end()
        # NOTE(review): leaves the batch loop after the first batch.
        break

    # # --- evaluation loop ---
    # if epoch % args.logging_iter == 0:
    #     gen_sig = []
    #     real_sig = []
    #     model.eval()
    #     with torch.no_grad():
    #         with model.ema_scope():
    #             process = DiffusionProcess(args, model.net,
    #                                         (args.input_channels, args.img_resolution, args.img_resolution))
    #             for data in tqdm(test_loader):
    #                 # sample from the model
    #                 x_img_sampled = process.sampling(sampling_number=data[0].shape[0])
    #                 # --- convert to time series --
    #                 x_ts = model.img_to_ts(x_img_sampled)
                    

    #                 break
    # NOTE(review): leaves the epoch loop after the first epoch.
    break






None
