In [None]:

from types import SimpleNamespace

import torch
import torch.nn as nn

class BasicTransformer(nn.Module):
    """Transformer over pairwise spectral differences between signal segments.

    Each channel of the input is split into ``N = 2 * segment_num + 1`` equal
    segments. For every segment the ``log(1 + |rFFT|)`` spectrum is computed and
    binned by summing groups of ``sum_up_length`` consecutive frequency bins.
    All ordered segment pairs are then differenced and signed-min-max
    normalized, giving ``N * N`` tokens per channel that are fed to a
    Transformer encoder. The mean-pooled encoding is projected to 2 output
    values (presumably class logits) per (batch, channel) pair.

    Args:
        args: Object exposing ``segment_length``, ``sum_up_length`` and
            ``segment_num`` attributes. Defaults to 500 / 5 / 7 when ``None``.
        hidden_dim: Encoder model width.
        num_layers: Number of encoder layers.
        nhead: Number of attention heads (PyTorch requires it to divide
            ``hidden_dim``).
    """

    def __init__(self, args=None, hidden_dim=16, num_layers=2, nhead=1):
        super().__init__()

        if args is None:
            args = SimpleNamespace(segment_length=500, sum_up_length=5, segment_num=7)

        # rfft of a length-L signal yields L // 2 + 1 bins; forward() keeps only
        # whole groups of ``sum_up_length`` bins, hence the floor division.
        num_freq_bins = args.segment_length // 2 + 1
        input_dim = num_freq_bins // args.sum_up_length

        self.sum_up_length = args.sum_up_length
        self.segment_num = args.segment_num

        # One token per ordered pair of the 2 * segment_num + 1 segments.
        seq_len = (2 * args.segment_num + 1) ** 2

        self.input_proj = nn.Linear(input_dim, hidden_dim)
        # At least 500 positions so existing default-sized checkpoints still
        # load, but grow when the token sequence would otherwise overflow it.
        self.pos_embedding = nn.Parameter(torch.randn(1, max(500, seq_len), hidden_dim))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=nhead,
            dim_feedforward=hidden_dim * 4,
            dropout=0.1,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=num_layers, enable_nested_tensor=False
        )

        self.fc_out = nn.Linear(hidden_dim, 2)

    def differential_signed_min_max(self, x, eps=1e-8):
        """Build pairwise segment differences and signed-min-max normalize them.

        For every pair (i, j) computes ``x[..., i, :] - x[..., j, :]`` and divides
        each feature by the largest absolute difference of that feature over
        all m * m pairs, so values lie in [-1, 1] with their sign preserved.

        Dimensions:
            b: batch_size
            c: channel_num
            m: segment_num
            n: feature_dim

        Args:
            x: Tensor of shape (b, c, m, n).
            eps: Lower bound on the normalizer; prevents 0 / 0 = NaN when a
                feature is identical across all segments.

        Returns:
            Tensor of shape (b, c, n, m, m).
        """
        b, c, m, n = x.shape
        # diff[..., i, j, :] = x[..., i, :] - x[..., j, :]  -> (b, c, m, m, n)
        diff = x.unsqueeze(-2) - x.unsqueeze(-3)
        # Per-feature max |difference| over all pairs -> (b, c, n)
        max_abs = diff.abs().reshape(b, c, m * m, n).amax(dim=-2)
        # Dividing the signed difference equals |diff| / max with the sign restored.
        norm = diff / max_abs.clamp_min(eps)[:, :, None, None, :]
        return norm.permute(0, 1, 4, 2, 3)

    def forward(self, x, p=None):
        """Classify each (batch, channel) signal.

        Args:
            x: Tensor of shape (batch, channels, length). ``length`` must be
                divisible by ``2 * segment_num + 1``, and each segment should be
                ``args.segment_length`` long so the binned spectrum matches
                ``input_proj``.
            p: Unused; kept for interface compatibility.

        Returns:
            Tensor of shape (batch * channels, 2).
        """
        b, c, _ = x.shape
        num_segments = 2 * self.segment_num + 1
        num_tokens = num_segments * num_segments

        # (b, c, L) -> (b, c, N, L / N): split each channel into N equal segments.
        x = x.reshape(b, c, num_segments, -1)

        # log(1 + |rFFT|) magnitude spectrum of each segment.
        x = torch.log1p(torch.abs(torch.fft.rfft(x, dim=-1)))

        # Sum groups of ``sum_up_length`` consecutive bins, dropping the remainder.
        num_bins = x.shape[-1] // self.sum_up_length
        x = x[..., : num_bins * self.sum_up_length]
        x = x.reshape(b, c, num_segments, num_bins, self.sum_up_length).sum(dim=-1)

        # (b, c, n, N, N) -> tokens (b, c, N * N, n)
        x = self.differential_signed_min_max(x)
        x = x.permute(0, 1, 3, 4, 2).reshape(b, c, num_tokens, -1)

        x = self.input_proj(x)  # (b, c, N * N, hidden_dim)
        # Each channel becomes its own sequence for the encoder.
        x = x.reshape(b * c, num_tokens, -1)
        x = x + self.pos_embedding[:, :num_tokens, :]

        encoded = self.transformer_encoder(x)
        pooled = encoded.mean(dim=1)
        return self.fc_out(pooled)
    


if __name__ == "__main__":
    # Smoke test: batch of 2, 3 channels, each (2 * segment_num + 1) * segment_length
    # = 15 * 500 samples long for the default configuration.
    model = BasicTransformer()
    model.eval()  # disable dropout so the demo output is deterministic
    input_tensor = torch.randn(2, 3, 15 * 500)
    with torch.no_grad():
        output = model(input_tensor)

    print(output.shape)  # Expected: torch.Size([6, 2]), one row per (batch, channel)
