In [1]:
import os
import fim
import torch
import numpy as np

from torch import nn
from torch.utils.data import DataLoader
from fim.data.dataloaders import (
    DataLoaderFactory
)
from fim.utils.helper import (
    GenericConfig, 
    expand_params, 
    load_yaml
)

from fim.data.datasets import FIMSDEDataset,FIMSDEDatabatchTuple
from fim.data.config_dataclasses import FIMDatasetConfig
from fim.models.blocks import ModelFactory
from fim.models.config_dataclasses import FIMSDEConfig
from fim.data.config_dataclasses import FIMDatasetConfig

from fim.models.blocks.base import (
    Mlp,
    TimeEncoding,
    TransformerModel
)

In [2]:
# Define model and data.
# The training config path can be overridden with the FIM_SDE_TRAIN_CONFIG environment
# variable; it falls back to the original local path so existing runs behave as before.
parameters_yaml = os.environ.get(
    "FIM_SDE_TRAIN_CONFIG",
    r"C:\Users\cesar\Desktop\Projects\FoundationModels\FIM\configs\train\fim-sde\fim-train-patrick.yaml",
)

config = load_yaml(parameters_yaml, return_object=True)

# Seed torch (CPU + CUDA) and numpy with the same experiment seed, converted once.
seed = int(config.experiment.seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
torch.cuda.empty_cache()
device_map = config.experiment.device_map

# From here on the config is used in its dict form.
config = config.to_dict()
dataloader = DataLoaderFactory.create(**config["dataset"])
if hasattr(dataloader, "update_kwargs"):
    # The FIM model requires that the config is updated after loading the data.
    dataloader.update_kwargs(config)

# Take a single training batch to drive the shape/prototyping cells below.
databatch: FIMSDEDatabatchTuple = next(iter(dataloader.train_it))
model = ModelFactory.create(config, device_map=device_map, resume=False)

Max Hypercube Size: 1024
Max Dimension: 3
Max Num Steps: 128
Max Hypercube Size: 1024
Max Dimension: 3
Max Num Steps: 128
Max Hypercube Size: 1024
Max Dimension: 3
Max Num Steps: 128


In [3]:
###################################################################################
# To change: reference copy of the model hyperparameters (YAML with anchors) that the
# component-definition cell reads through model.model_params. Presumably this mirrors the
# model section of the training YAML loaded above; verify against that file.
# The string is only echoed as this cell's output; nothing else in the notebook uses it.
"""
  # phi_0^t
  temporal_embedding_size: &temporal_embedding_size 19

  # phi_0^s
  spatial_embedding_size: &spatial_embedding_size 19
  spatial_embedding_hidden_layers: &spatial_embedding_hidden_layers null #  if null, this will just be dense layer

  # psi_1
  sequence_encoding_transformer_hidden_size: &sequence_encoding_transformer_hidden_size 28 
  sequence_encoding_transformer_heads: &sequence_encoding_transformer_heads 2
  sequence_encoding_transformer_layers: &sequence_encoding_transformer_layers 2

  # Omega_1
  combining_transformer_hidden_size: &combining_transformer_hidden_size 28 
  combining_transformer_heads: &combining_transformer_heads 2
  combining_transformer_layers: &combining_transformer_layers 1

  # phi_1
  trunk_net_size: &trunk_net_size 28 
  trunk_net_hidden_layers: &trunk_net_hidden_layers null
"""

'\n  # phi_0^t\n  temporal_embedding_size: &temporal_embedding_size 19\n\n  # phi_0^s\n  spatial_embedding_size: &spatial_embedding_size 19\n  spatial_embedding_hidden_layers: &spatial_embedding_hidden_layers null #  if null, this will just be dense layer\n\n  # psi_1\n  sequence_encoding_transformer_hidden_size: &sequence_encoding_transformer_hidden_size 28 \n  sequence_encoding_transformer_heads: &sequence_encoding_transformer_heads 2\n  sequence_encoding_transformer_layers: &sequence_encoding_transformer_layers 2\n\n  # Omega_1\n  combining_transformer_hidden_size: &combining_transformer_hidden_size 28 \n  combining_transformer_heads: &combining_transformer_heads 2\n  combining_transformer_layers: &combining_transformer_layers 1\n\n  # phi_1\n  trunk_net_size: &trunk_net_size 28 \n  trunk_net_hidden_layers: &trunk_net_hidden_layers null\n'

In [83]:
# Build the model components standalone from the loaded model's hyperparameters,
# so the forward pass can be prototyped step by step in the cells below.
model_params: FIMSDEConfig = model.model_params
data_params: FIMDatasetConfig = model.data_params

# Derived sizes
x_dimension = data_params.max_dimension
# Three input features per dimension (original note: "we encode the difference and its square");
# must match the width of x_flattened in the forward-pass cell.
x_dimension_full = x_dimension * 3
spatial_plus_time_encoding = model_params.temporal_embedding_size + model_params.spatial_embedding_size
# Token width shared by every attention block below. It is built as a multiple of
# sequence_encoding_transformer_heads, presumably so psi_1 can split it evenly across its heads.
psi_1_tokes_dim = model_params.sequence_encoding_tokenizer * model_params.sequence_encoding_transformer_heads

# omega_1 / omega_2 use combining_transformer_heads, which need not divide psi_1_tokes_dim.
# nn.MultiheadAttention requires embed_dim % num_heads == 0, so fail early with a clear message.
if psi_1_tokes_dim % model_params.combining_transformer_heads != 0:
    raise ValueError(
        f"psi_1_tokes_dim ({psi_1_tokes_dim}) must be divisible by "
        f"combining_transformer_heads ({model_params.combining_transformer_heads}) "
        "for omega_1 / omega_2"
    )

# phi_0^t, phi_0^x: basic time and spatial embeddings
phi_0t = TimeEncoding(model_params.temporal_embedding_size)
phi_0x = Mlp(
    in_features=x_dimension_full,
    out_features=model_params.spatial_embedding_size,
    hidden_layers=model_params.spatial_embedding_hidden_layers,
)

# Trunk network: embeds the query locations directly into the token space, because its
# output is used as the attention query of omega_1 (query width must equal embed_dim).
# NOTE(review): model_params.trunk_net_size is therefore not used here.
trunk = Mlp(
    in_features=x_dimension,
    out_features=psi_1_tokes_dim,
    hidden_layers=model_params.trunk_net_hidden_layers,
)

# Projects the concatenated [spatial, time] embedding to the token width psi_1 expects.
# (Head divisibility comes from how psi_1_tokes_dim is computed above, not from this layer.)
phi_xt = nn.Linear(spatial_plus_time_encoding, psi_1_tokes_dim)

# psi_1: path transformer (original note: causal encoding of paths)
psi_1 = TransformerModel(
    input_dim=psi_1_tokes_dim,
    nhead=model_params.sequence_encoding_transformer_heads,
    hidden_dim=model_params.sequence_encoding_transformer_hidden_size,
    nlayers=model_params.sequence_encoding_transformer_layers,
)

# omega_1: attention over time
omega_1 = nn.MultiheadAttention(
    psi_1_tokes_dim,
    model_params.combining_transformer_heads,
    batch_first=True,
)

# omega_2: attention over paths, driven by a single learned query token
path_queries = nn.Parameter(torch.randn(1, psi_1_tokes_dim))
omega_2 = nn.MultiheadAttention(
    psi_1_tokes_dim,
    model_params.combining_transformer_heads,
    batch_first=True,
)

# TODO: output heads are not defined yet:
#   drift_head, var_drift_head, diffusion_head, var_diffusion_head



In [78]:
# Inspect the token width used by phi_xt, psi_1, omega_1 and omega_2.
# NOTE(review): this cell was run (In [78]) before the current model-components cell (In [83]),
# so the saved output may not reflect the current config; re-run after that cell.
psi_1_tokes_dim

10

In [79]:
#model_params.trunk_net_size

In [80]:
#model_params.combining_transformer_heads

In [81]:
# Prototype forward pass on the sample batch, using the components built above.
# Shape legend: B batch, P paths, T time steps, D dimensions, G query locations,
# d = psi_1_tokes_dim.
B, P, T, D, _ = databatch.obs_values.shape
G = databatch.locations.size(1)

# phi_0^x input: the observation channels plus the square of channel 0, flattened per time step
# (original note: "include the square of the difference").
x_full = torch.cat([databatch.obs_values, databatch.obs_values[..., :1] ** 2], dim=-1)
x_flattened = x_full.view(B, P, T, -1)
assert x_flattened.size(-1) == x_dimension_full, (
    f"phi_0x expects {x_dimension_full} input features, got {x_flattened.size(-1)}"
)
spatial_encoding = phi_0x(x_flattened)  # [B, P, T, spatial_embedding_size]
time_encoding = phi_0t(databatch.obs_times)  # [B, P, T, temporal_embedding_size]

# Trunk: encode the query locations once per batch element, then share them across the P paths.
trunk_encoding = trunk(databatch.locations)  # [B, G, d]
trunk_encoding = trunk_encoding.repeat_interleave(P, dim=0)  # [B*P, G, d]; row b*P + p holds batch b

# Embedded input tokens
U = torch.cat([spatial_encoding, time_encoding], dim=-1)  # [B, P, T, spatial_plus_time_encoding]
U = phi_xt(U)  # [B, P, T, d]

# psi_1: one representation per time step of every path. psi_1 is called
# sequence-first, hence the transposes around it.
U = U.view(B * P, T, psi_1_tokes_dim)
H = psi_1(torch.transpose(U, 0, 1))  # [T, B*P, d]
H = torch.transpose(H, 0, 1)  # [B*P, T, d]

# omega_1: attention over time -> one representation per (path, query location).
# NOTE(review): no key_padding_mask is passed; confirm whether padded time steps need masking.
hx, _ = omega_1(trunk_encoding, H, H)  # [B*P, G, d]
hx = hx.view(B, P, G, -1)  # [B, P, G, d]

# omega_2: attention over paths -> one representation per (batch element, query location).
hx = hx.transpose(1, 2).reshape(B * G, P, -1)  # [B*G, P, d]; row b*G + g
path_queries_ = path_queries[None, :, :].repeat(B * G, 1, 1)  # [B*G, 1, d]
bx, _ = omega_2(path_queries_, hx, hx)  # [B*G, 1, d]
bx = bx.view(B, G, -1)  # [B, G, d]
assert bx.shape == (B, G, psi_1_tokes_dim)

In [82]:

#omega_1(trunk_encoding,H,H)

In [75]:
# Final per-location representation; the forward-pass cell reshapes bx to (B, G, psi_1_tokes_dim).
bx.shape

torch.Size([2, 1024, 10])

In [55]:
# NOTE(review): the saved output of this cell is stale (In [55] ran before the forward-pass
# cell In [81]); after that cell hx is the (B*G, P, psi_1_tokes_dim) input to omega_2.
hx.shape

torch.Size([2, 300, 1024, 10])

In [35]:
# Query-location encodings after sharing across paths; the forward-pass cell shapes them
# to (B*P, G, -1), with the last dim being the trunk's psi_1_tokes_dim outputs.
trunk_encoding.shape

torch.Size([600, 1024, 10])

In [18]:
# NOTE(review): duplicate `hx.shape` check; its saved output predates the forward-pass cell
# (In [18] < In [81]) and no longer matches hx's layout. Consider deleting this cell.
hx.shape

torch.Size([1024, 600, 10])

In [None]:
# Reshape queries to match the attention requirements
# H = self.psi1(torch.transpose(H,0,1)) # (seq_lenght,batch_size,encoding0_dim)

# tx = tx.reshape(num_hyper, batch_size, self.encoding0_dim)  # Shape: (num_hyper, batch_size, encoding0_dim)

# Representation per path
# attn_output, _ = multihead_attn(queries[:,None,:].repeat(1,batch_size,1), H, H) # Shape: (1, batch_size, query_dim)
#attn_output, _ = self.omega_1(tx, H, H) # Shape: (num_hyper, batch_size, query_dim)
#attn_output = torch.transpose(attn_output,1,0) # Shape: (num_hyper, batch_size, query_dim)
#attn_output = attn_output.reshape(num_hyper*batch_size,self.encoding0_dim)


In [11]:
#H  = H.reshape(batch_size,num_steps,self.encoding0_dim) 
#H = self.psi1(torch.transpose(H,0,1)) # (seq_lenght,batch_size,encoding0_dim)

In [13]:
# Reference snippet of what appears to be an earlier module-definition block (attributes such
# as self.phi_x0, self.psi1, self.omega_1 and params.* fields), kept as a string for comparison
# with the components built above. It is not executed: the cell only echoes the string.
# NOTE(review): consider moving this into markdown or deleting it once the new layout is settled.
"""Defines all the nn Modules need for the architecture
self.encoding0_dim = params.dim_time + params.x0_out_features
self.phi_t0 = TimeEncoding(params.dim_time)

self.phi_x0 = Mlp(in_features=params.max_dimension,
                    out_features=params.x0_out_features,
                    hidden_layers=params.x0_hidden_layers,
                    output_act=nn.SiLU())

self.phi_1 = Mlp(in_features=params.max_dimension,
                    out_features=params.max_dimension,
                    hidden_layers=params.x0_hidden_layers)

self.phi_2 = Mlp(in_features=self.encoding0_dim,
                    out_features=params.max_dimension,
                    hidden_layers=params.x0_hidden_layers)

self.psi1 = TransformerModel(input_dim=self.encoding0_dim, 
                                nhead=params.n_heads, 
                                hidden_dim=params.psi1_hidden_dim, 
                                nlayers=params.psi1_nlayers)

#self.queries = nn.Parameter(torch.randn(1, params.encoding0_dim))
self.query_1x = QueryGenerator(input_dim=params.max_dimension,
                                query_dim=params.encoding0_dim)

self.query_1 =  StaticQuery(num_steps=params.max_num_steps,
                    query_dim=params.encoding0_dim)

# Create the MultiheadAttention module
self.omega_1 = nn.MultiheadAttention(self.encoding0_dim, params.n_heads)
"""

'Defines all the nn Modules need for the architecture\nself.encoding0_dim = params.dim_time + params.x0_out_features\nself.phi_t0 = TimeEncoding(params.dim_time)\n\nself.phi_x0 = Mlp(in_features=params.max_dimension,\n                    out_features=params.x0_out_features,\n                    hidden_layers=params.x0_hidden_layers,\n                    output_act=nn.SiLU())\n\nself.phi_1 = Mlp(in_features=params.max_dimension,\n                    out_features=params.max_dimension,\n                    hidden_layers=params.x0_hidden_layers)\n\nself.phi_2 = Mlp(in_features=self.encoding0_dim,\n                    out_features=params.max_dimension,\n                    hidden_layers=params.x0_hidden_layers)\n\nself.psi1 = TransformerModel(input_dim=self.encoding0_dim, \n                                nhead=params.n_heads, \n                                hidden_dim=params.psi1_hidden_dim, \n                                nlayers=params.psi1_nlayers)\n\n#self.queries = nn.Parameter(to