In [2]:
%load_ext autoreload
%autoreload 2

from seqsketch.utils.config import Config
from pathlib import Path
import torch
import numpy as np

In [168]:
# Experiment configuration: the YAML file (under seqsketch/configs/) that drives this notebook.
file = "test.yaml"
configurator = Config(config_file=f"seqsketch/configs/{file}")

# Build everything downstream cells rely on, in the same order as before:
# the parsed config, the data module, and the model.
config = configurator.get_config()
dataloader = configurator.get_dataloader()
model = configurator.get_model()


In [None]:
#from seqsketch.models import StrokeDenoiser,StrokeEncoder, CrossAttentionDecoder

In [169]:
# Grab a single training batch to experiment with.
batch = next(iter(dataloader.train_dataloader()))
# NOTE(review): this unpacking is positional and so depends on the batch dict's key order;
# unpacking by key would be more robust.
(_, _, x, c, x_mask, c_mask) = tuple(batch.values())

In [293]:
import torch
import torch.nn as nn

class CrossAttentionDecoder(nn.Module):
    """Transformer decoder that cross-attends target embeddings to a time-conditioned context.

    The conditioning embeddings are concatenated with a broadcast copy of the time
    embedding, projected back to ``d_model`` and used as the decoder memory. Every
    target position is then mapped to a 2-D ``(x, y)`` prediction.

    Args:
        d_model: Embedding width of ``x_emb``, ``c_emb`` and ``t_emb``.
        nhead: Attention heads per decoder layer (must divide ``d_model``).
        ff_dim: Hidden width of each layer's feed-forward sub-block.
        num_layers: Number of stacked decoder layers.
        seq_length: Stored as ``self.seq_length``; not used by ``forward``.
    """

    def __init__(self, d_model=128, nhead=4, ff_dim=256, num_layers=3, seq_length=32):
        super().__init__()
        self.seq_length = seq_length  # kept for callers that read it; forward() does not use it
        self.output_dim = 2  # (x, y) coordinates

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=ff_dim, batch_first=True
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        # Fuses [condition || time] (2 * d_model) back down to the decoder width.
        self.combined_embedding = nn.Linear(2 * d_model, d_model)
        # Per-position projection to the 2-D prediction.
        self.final = nn.Linear(d_model, self.output_dim)

    def forward(self, x_emb, c_emb, t_emb, x_padding_mask=None, c_padding_mask=None):
        """Predict a 2-D value for every target position.

        Args:
            x_emb: Target embeddings, (B, L, D) - noisy data.
            c_emb: Condition embeddings, (B, M, D) - clean historical data.
            t_emb: Time embedding, (B, D); broadcast over the M condition positions.
            x_padding_mask: Optional (B, L) key-padding mask for the target sequence.
            c_padding_mask: Optional (B, M) key-padding mask for the condition sequence.
                Both masks follow PyTorch's convention: ``True`` marks a position to
                IGNORE. A mask where 1 means "valid" must be inverted first. A row whose
                memory positions are all masked yields NaNs.

        Returns:
            Noise prediction of shape (B, L, 2).
        """
        _, num_context, _ = c_emb.size()

        # expand() broadcasts the time embedding without materialising M copies.
        t_expanded = t_emb.unsqueeze(1).expand(-1, num_context, -1)  # (B, M, D)
        context = torch.cat([c_emb, t_expanded], dim=-1)  # (B, M, 2 * D)
        memory = self.combined_embedding(context)  # (B, M, D)

        decoded = self.decoder(
            x_emb,
            memory,
            tgt_key_padding_mask=x_padding_mask,
            memory_key_padding_mask=c_padding_mask,
        )  # (B, L, D)
        return self.final(decoded)  # (B, L, 2)


In [294]:
from seqsketch.models.encoder import StrokeEncoder

# Embed the target (x) and conditioning (c) strokes with one shared, freshly initialised encoder.
encoder = StrokeEncoder()
x_emb = encoder(x, x_mask)
c_emb = encoder(c, c_mask)
# Random stand-in for the time embedding. Size it from c_emb instead of hard-coding (16, 128):
# CrossAttentionDecoder concatenates it with c_emb, so batch size and width must match
# (assumes c_emb is (B, M, D), as the decoder requires).
t_emb = torch.randn(c_emb.size(0), c_emb.size(-1))

In [295]:
# Run an untrained decoder on the sample embeddings; the result is a (B, L, 2) noise prediction.
decoder = CrossAttentionDecoder()
out = decoder(x_emb=x_emb, c_emb=c_emb, t_emb=t_emb)

tensor(False)

In [None]:



class MultiStrokeEncoder(nn.Module):
    """Downsample a feature sequence with two Conv1d+MaxPool1d stages, then optionally contextualise it.

    Args:
        input_dim: Feature channels of the input (size of ``stroke``'s last dim).
        num_filters: Output channels of the first conv; the second conv doubles this.
        kernel_size: Conv1d kernel size (padding is ``kernel_size // 2``).
        stride: Conv1d stride.
        pool_stride: MaxPool1d kernel size (and therefore stride) after each conv.
        nhead: Attention heads in the transformer encoder; must divide ``2 * num_filters``.
        num_layers: Depth of both the transformer encoder and the GRU.
        mode: ``"tf"`` (transformer encoder), ``"rnn"`` (GRU) or ``"neither"`` (conv features only).

    Raises:
        ValueError: if ``mode`` is not one of the three supported values.
    """

    def __init__(self, input_dim=128, num_filters=64, kernel_size=3, stride=2, pool_stride=2,
                 nhead=4, num_layers=3, mode="tf"):
        super().__init__()
        if mode not in ("tf", "rnn", "neither"):
            # Fail at construction rather than on the first forward pass.
            raise ValueError("Ensure that mode is either 'rnn', 'tf' or 'neither'.")
        hidden_dim = num_filters * 2  # feature width after the second conv

        # Two conv stages, each followed by max-pooling.
        self.conv1 = nn.Conv1d(input_dim, num_filters, kernel_size=kernel_size, stride=stride,
                               padding=kernel_size // 2)
        self.pool1 = nn.MaxPool1d(pool_stride)
        self.conv2 = nn.Conv1d(num_filters, hidden_dim, kernel_size=kernel_size, stride=stride,
                               padding=kernel_size // 2)
        self.pool2 = nn.MaxPool1d(pool_stride)

        self.mode = mode
        # Both heads are always built regardless of mode; forward() uses at most one.
        transformer_encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=nhead,
                                                               batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(transformer_encoder_layer, num_layers=num_layers)
        # The GRU consumes the conv output, so its width must be hidden_dim, not input_dim
        # (the two only coincide for the defaults, 128 == 2 * 64).
        self.rnn = nn.GRU(hidden_dim, hidden_dim, num_layers=num_layers, batch_first=True)

    def forward(self, stroke, mask=None):
        """Encode ``stroke`` of shape (B, L, input_dim) into (B, L_out, 2 * num_filters).

        Each conv+pool stage shrinks the sequence by roughly ``stride * pool_stride``
        (4x with the defaults), so with the defaults L_out is about L / 16.

        Args:
            stroke: (B, L, input_dim) tensor; permuted to channels-first for Conv1d.
            mask: Accepted for interface compatibility but currently ignored.
                NOTE(review): padded positions are therefore mixed into conv/attention.

        Returns:
            (B, L_out, 2 * num_filters) tensor.
        """
        h = stroke.permute(0, 2, 1)    # (B, input_dim, L) for Conv1d
        h = self.pool1(self.conv1(h))  # (B, num_filters, ~L / (stride * pool_stride))
        h = self.pool2(self.conv2(h))  # (B, 2 * num_filters, ~L / (stride * pool_stride) ** 2)
        h = h.permute(0, 2, 1)         # back to batch-first (B, L_out, 2 * num_filters)
        if self.mode == "tf":
            h = self.transformer_encoder(h)
        elif self.mode == "rnn":
            h, _ = self.rnn(h)  # GRU returns (outputs, h_n); keep the per-step outputs
        else:
            assert self.mode == "neither", "Ensure that mode is either 'rnn', 'tf' or 'neither'."
        return h


multi_encoder = MultiStrokeEncoder()

In [163]:
# NOTE(review): hidden state. The recorded output (16, 64, 128) came from an earlier `out`
# (presumably about (16, 1024, 128), since the default encoder downsamples ~16x). On
# Restart & Run All, `out` is the CrossAttentionDecoder prediction of shape (B, L, 2),
# whose 2 channels do not match MultiStrokeEncoder's input_dim=128, so this cell fails.
o1 = multi_encoder(out)
o1.shape

torch.Size([16, 64, 128])


tensor(False)

In [139]:
# Per-stroke lengths for the conditioning mask, plus which strokes are non-empty.
mask = c_mask
# NOTE(review): hard-coded sizes matching the (B x N x L) mask layout below; they must
# agree with the dataloader/config or the reshape fails.
B = 16
L = 32
N = 32
if mask is None:
    # No mask: every stroke counts as full length.
    lengths = torch.full((B * N,), L, dtype=torch.long)
else:
    if mask.dim() == 4:
        # Collapse the trailing axis (presumably one entry per x/y coordinate) into a
        # single per-point value - TODO confirm both coordinates are always masked together.
        mask = mask.sum(dim=-1) / 2                  # (B x N x L)
    # reshape (not view) also works when the incoming mask is non-contiguous.
    stroke_mask = mask.reshape(B * N, L)             # (B*N, L)
    lengths = stroke_mask.sum(dim=-1).cpu().long()   # (B*N,)

# Identify valid sequences (non-zero lengths). Computed outside the if/else so it is
# defined for both branches (previously missing when mask is None).
valid_indices = lengths > 0

In [146]:
# Inspect every batch element at index 30 of dim 1. In the recorded output the rows are
# nearly identical across the batch, which may indicate the inputs are being ignored.
out.select(1, 30)

tensor([[-0.1024,  0.0372,  0.0032,  ...,  0.0843, -0.0519, -0.0910],
        [-0.1024,  0.0372,  0.0032,  ...,  0.0843, -0.0519, -0.0910],
        [-0.1024,  0.0372,  0.0032,  ...,  0.0843, -0.0519, -0.0909],
        ...,
        [-0.1024,  0.0372,  0.0032,  ...,  0.0843, -0.0519, -0.0910],
        [-0.1024,  0.0372,  0.0032,  ...,  0.0843, -0.0519, -0.0910],
        [-0.1024,  0.0372,  0.0032,  ...,  0.0843, -0.0519, -0.0910]],
       grad_fn=<SelectBackward0>)

In [117]:
# Shape sanity check: a batch-first decoder layer returns a tensor shaped like its target.
decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)
memory = torch.rand((32, 30, 512))  # (batch, memory_len, d_model)
tgt = torch.rand((32, 20, 512))     # (batch, tgt_len, d_model)
out3 = decoder_layer(tgt=tgt, memory=memory)
out3.shape

torch.Size([32, 20, 512])

tensor([[[ 0.4902,  0.2627],
         [ 0.9922,  0.2588],
         [-1.0000, -1.0000],
         ...,
         [-1.0000, -1.0000],
         [-1.0000, -1.0000],
         [-1.0000, -1.0000]],

        [[ 0.2510,  0.7843],
         [ 0.2431,  0.3294],
         [-1.0000, -1.0000],
         ...,
         [-1.0000, -1.0000],
         [-1.0000, -1.0000],
         [-1.0000, -1.0000]],

        [[ 0.0235,  0.3176],
         [ 0.0471,  0.3020],
         [ 0.0745,  0.3020],
         ...,
         [-1.0000, -1.0000],
         [-1.0000, -1.0000],
         [-1.0000, -1.0000]],

        ...,

        [[-1.0000, -1.0000],
         [-1.0000, -1.0000],
         [-1.0000, -1.0000],
         ...,
         [-1.0000, -1.0000],
         [-1.0000, -1.0000],
         [-1.0000, -1.0000]],

        [[ 0.5569,  0.2431],
         [ 0.0627,  0.2392],
         [ 0.1725,  0.2196],
         ...,
         [-1.0000, -1.0000],
         [-1.0000, -1.0000],
         [-1.0000, -1.0000]],

        [[ 0.2275,  0.9137],
       