We first define the multiHead attention 

In [198]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

import math
import numpy as np
import seaborn as sns
import matplotlib.pylab as plt
import pandas as pd

In [199]:
def scaled_dot_product(q, k, v, mask=None):
    """ Computes the Scaled Dot-Product Attention

    Args:
        q (torch.FloatTensor):  Query Tensor   (... x T_q x d_q)
        k (torch.FloatTensor):  Key Tensor     (... x T_k x d_k)
        v (torch.FloatTensor):  Value Tensor   (... x T_v x d_v)
        mask (torch.BoolTensor): Attention mask (... x T_q x T_k)

    Returns:
        torch.FloatTensor: Result of the SDPA  (... x T_q x d_v)
        torch.FloatTensor: Attention map       (... x T_q x T_k)

    """
    assert q.size(-1) == k.size(-1), "Query and Key dimensions must coincide"

    # TODO: Matrix multiplication of the queries and the keys (use torch.matmul)
    #attn_logits =
    attn_logits = torch.matmul(q, k.transpose(-2, -1))

    # TODO: Scale attn_logits (see the SDPA formula, d_k is the last dim of k)
    #attn_logits = 
    attn_logits = attn_logits/torch.sqrt(torch.tensor(k.size(-1), dtype=torch.float32))

    if mask is not None:
        attn_logits = attn_logits.masked_fill(mask, -float("inf"))

    # TODO: Compute the attention weights (see the SDPA formula, use dim=-1)
    #attention =
    attention = torch.softmax(attn_logits, dim=-1)

    output = torch.matmul(attention, v)

    return output, attention

In [200]:
class MultiheadAttention(nn.Module):

    def __init__(self, embed_dim, num_heads):
        super(MultiheadAttention, self).__init__()
        assert embed_dim % num_heads == 0, \
            "Embedding dimension must be multiple of the number of heads."

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.proj_q = nn.Linear(embed_dim, embed_dim)
        self.proj_k = nn.Linear(embed_dim, embed_dim)
        self.proj_v = nn.Linear(embed_dim, embed_dim)
        self.proj_o = nn.Linear(embed_dim, embed_dim)

        self._reset_parameters()

    def _reset_parameters(self):
        # Original Transformer initialization
        nn.init.xavier_uniform_(self.proj_q.weight)
        nn.init.xavier_uniform_(self.proj_k.weight)
        nn.init.xavier_uniform_(self.proj_v.weight)
        nn.init.xavier_uniform_(self.proj_o.weight)
        self.proj_q.bias.data.fill_(0)
        self.proj_k.bias.data.fill_(0)
        self.proj_v.bias.data.fill_(0)
        self.proj_o.bias.data.fill_(0)

    def forward(self, q, k, v, mask=None):
        batch_size = q.size(1)

        q = self.proj_q(q)
        k = self.proj_k(k)
        v = self.proj_v(v)

        # TODO: Split the tensors into multiple heads
        #  T x B x embed_dim -> T x B x num_heads x head_dim
        q = q.reshape(q.size(0), batch_size, self.num_heads, self.head_dim)
        k = k.reshape(k.size(0), batch_size, self.num_heads, self.head_dim)
        v = v.reshape(v.size(0), batch_size, self.num_heads, self.head_dim)

        # The last two dimensions must be sequence length and the head dimension,
        # to make it work with the scaled dot-product function.
        # TODO: Rearrange the dimensions
        # T x B x num_heads x head_dim -> B x num_heads x T x head_dim
        q = q.permute(1, 2, 0, 3)
        k = k.permute(1, 2, 0, 3)
        v = v.permute(1, 2, 0, 3)

        # Apply the same mask to all the heads
        if mask is not None:
            mask = mask.unsqueeze(1)

        # TODO: Call the scaled dot-product function (remember to pass the mask!)
        output_heads, attn_w = scaled_dot_product(q, k, v, mask)

        # B x num_heads x T x head_dim -> T x B x num_heads x head_dim
        output_heads = output_heads.permute(2, 0, 1, 3)

        # T x B x num_heads x head_dim -> T x B x embed_dim
        output_cat = output_heads.reshape(-1, batch_size, self.embed_dim)
        output = self.proj_o(output_cat)

        return output, attn_w

In [201]:
class PositionalEncoding(nn.Module):

    def __init__(self, embed_dim, max_len=5000):
        """
        Args:
            embed_dim (int): Embedding dimensionality
            max_len (int): Maximum length of a sequence to expect
        """
        super(PositionalEncoding, self).__init__()

        # Create matrix of (T x embed_dim) representing the positional encoding
        # for max_len inputs
        pe = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(1)

        self.register_buffer('pe', pe, persistent=False)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return x

In [202]:
class TransformerEncoderLayer(nn.Module):

    def __init__(self, embed_dim, ffn_dim, num_heads, dropout=0.0):
        """
        Args:
            embed_dim (int): Embedding dimensionality (input, output & self-attention)
            ffn_dim (int): Inner dimensionality in the FFN
            num_heads (int): Number of heads of the multi-head attention block
            dropout (float): Dropout probability
        """
        super(TransformerEncoderLayer, self).__init__()

        self.self_attn = MultiheadAttention(embed_dim, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ffn_dim),
            nn.Dropout(dropout),
            nn.ReLU(inplace=True),
            nn.Linear(ffn_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None, return_att=False):
        src_len, batch_size, _ = x.shape
        if mask is None:
            mask = torch.zeros(x.shape[1], x.shape[0]).bool().to(x.device)

        selfattn_mask = mask.unsqueeze(-2)

        # TODO: Self-Attention block
        selfattn_out, selfattn_w = self.self_attn(x, x, x, selfattn_mask)
        selfattn_out = self.dropout(selfattn_out)

        # TODO: Add + normalize block (1)
        x = self.norm1(x + selfattn_out)

        # TODO: FFN block
        ffn_out = self.ffn(x)
        ffn_out = self.dropout(ffn_out)

        # TODO: Add + normalize block (2)
        x = self.norm2(x + ffn_out)

        if return_att:
            return x, selfattn_w
        else:
            return x

In [203]:
class TransformerEncoder(nn.Module):

    def __init__(self, num_layers, embed_dim, ffn_dim, num_heads, dropout=0.0):
        super(TransformerEncoder, self).__init__()

        # Create an embedding table (T x B -> T x B x embed_dim)
        # self.embedding = nn.Embedding(, embed_dim)

        # Create the positional encoding with the class defined before
        self.pos_enc = PositionalEncoding(embed_dim)

        self.layers = nn.ModuleList([
            TransformerEncoderLayer(embed_dim, ffn_dim, num_heads, dropout)
            for _ in range(num_layers)
        ])

    def forward(self, x, mask=None, return_att=False):
        #x = self.embedding(x)
        x = self.pos_enc(x)

        selfattn_ws = []
        for l in self.layers:
            if return_att:
                x, selfattn_w = l(x, mask=mask, return_att=True)
                selfattn_ws.append(selfattn_w)
            else:
                x = l(x, mask=mask, return_att=False)

        if return_att:
            selfattn_ws = torch.stack(selfattn_ws, dim=1)
            return x, selfattn_ws
        else:
            return x

In [204]:
class TransformerDecoderLayer(nn.Module):

    def __init__(self, embed_dim, ffn_dim, num_heads, dropout=0.0):
        """
        Args:
            embed_dim (int): Embedding dimensionality (input, output & self-attention)
            ffn_dim (int): Inner dimensionality in the FFN
            num_heads (int): Number of heads of the multi-head attention block
            dropout (float): Dropout probability
        """
        super(TransformerDecoderLayer, self).__init__()

        self.self_attn = MultiheadAttention(embed_dim, num_heads)
        self.encdec_attn = MultiheadAttention(embed_dim, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ffn_dim),
            nn.Dropout(dropout),
            nn.ReLU(inplace=True),
            nn.Linear(ffn_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.norm3 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, memory, mask=None, memory_mask=None, return_att=False):
        tgt_len, batch_size, _ = x.shape
        src_len, _, _ = memory.shape
        if mask is None:
            mask = torch.zeros(x.shape[1], x.shape[0])
            mask = mask.bool().to(x.device)
        if memory_mask is None:
            memory_mask = torch.zeros(memory.shape[1], memory.shape[0])
            memory_mask = memory_mask.bool().to(memory.device)


        subsequent_mask = torch.triu(torch.ones(batch_size, tgt_len, tgt_len), 1)
        subsequent_mask = subsequent_mask.bool().to(mask.device)
        selfattn_mask = subsequent_mask + mask.unsqueeze(-2)

        attn_mask = memory_mask.unsqueeze(-2)

        # TODO: Self-Attention block
        selfattn_out, selfattn_w = self.self_attn(x, x, x, selfattn_mask)
        selfattn_out = self.dropout(selfattn_out)

        # TODO: Add + normalize block (1)
        x = self.norm1(x + selfattn_out)

        # TODO: Encoder-Decoder Attention block
        attn_out, attn_w = self.encdec_attn(x, memory, memory, attn_mask)
        attn_out = self.dropout(attn_out)

        # TODO: Add + normalize block (2)
        x = self.norm2(x + attn_out)

        # TODO: FFN block
        ffn_out = self.ffn(x)
        ffn_out = self.dropout(ffn_out)

        # TODO: Add + normalize block (3)
        x = self.norm3(x + ffn_out)

        if return_att:
            return x, selfattn_w, attn_w
        else:
            return x

In [205]:
class TransformerDecoder(nn.Module):

    def __init__(self, num_layers, embed_dim, ffn_dim, num_heads, dropout=0.0):
        super(TransformerDecoder, self).__init__()

        # Create an embedding table (T x B -> T x B x embed_dim)
        # self.embedding = nn.Embedding(vocab_size, embed_dim)

        # Create the positional encoding with the class defined before
        self.pos_enc = PositionalEncoding(embed_dim)

        self.layers = nn.ModuleList([
            TransformerDecoderLayer(embed_dim, ffn_dim, num_heads, dropout)
            for _ in range(num_layers)
        ])

        # Add a projection layer (T x B x embed_dim -> T x B x vocab_size)
        # self.proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, memory, mask=None, memory_mask=None, return_att=False):
        #x = self.embedding(x)
        x = self.pos_enc(x)

        selfattn_ws = []
        attn_ws = []
        for l in self.layers:
            if return_att:
                x, selfattn_w, attn_w = l(
                    x, memory, mask=mask, memory_mask=memory_mask, return_att=True
                )
                selfattn_ws.append(selfattn_w)
                attn_ws.append(attn_w)
            else:
                x = l(
                    x, memory, mask=mask, memory_mask=memory_mask, return_att=False
                )

        x = self.proj(x)
        x = F.log_softmax(x, dim=-1)

        if return_att:
            selfattn_ws = torch.stack(selfattn_ws, dim=1)
            attn_ws = torch.stack(attn_ws, dim=1)
            return x, selfattn_ws, attn_ws
        else:
            return x

In [206]:
class Transformer(nn.Module):
    def __init__(self, encoder_config, decoder_config):
        super(Transformer, self).__init__()
        self.encoder = TransformerEncoder(**encoder_config)
        self.decoder = TransformerDecoder(**decoder_config)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """ Forward method

        Method used at training time, when the target is known. The target tensor
        passed to the decoder is shifted to the right (starting with BOS
        symbol). Then, the output of the decoder starts directly with the first
        token of the sentence.
        """

        # TODO: Compute the encoder output
        encoder_out = self.encoder(src, src_mask)

        # TODO: Compute the decoder output
        decoder_out = self.decoder(
            x=tgt,
            memory=encoder_out,
            mask=tgt_mask,
            memory_mask=src_mask
        )

        return decoder_out

    def generate(self, src, src_mask=None, bos_idx=0, max_len=50):
        """ Generate method

        Method used at inference time, when the target is unknown. It
        iteratively passes to the decoder the sequence generated so far
        and appends the new token to the input again. It uses a Greedy
        decoding (argmax).
        """

        # TODO: Compute the encoder output
        encoder_out = self.encoder(src, src_mask)

        output = torch.LongTensor([bos_idx])\
                    .expand(1, encoder_out.size(1)).to(src.device)
        for i in range(max_len):
            # TODO: Get the new token
            new_token = self.decoder(
                x=output,
                memory=encoder_out,
                memory_mask=src_mask
            )[-1].argmax(-1)

            output = torch.cat([output, new_token.unsqueeze(0)], dim=0)

        return output

In [207]:
from seq2seq_numbers_dataset import generate_dataset_pytorch, Seq2SeqNumbersCollater

TRAIN_SIZE = 1000

data = pd.read_csv("./dataset_train_2024.csv", header=None)

label = data.iloc[1:, 258]

print(label)

data = data.iloc[1:, : 257]

trainData = data[:TRAIN_SIZE]
trainLabels = label[:TRAIN_SIZE]

data.head

collater = Seq2SeqNumbersCollater(
    trainData,
    trainLabels,
)


1        GFSK
2        BPSK
3        BPSK
4        BPSK
5        8PSK
         ... 
11996    BPSK
11997    8PSK
11998    QPSK
11999    8PSK
12000    QPSK
Name: 258, Length: 12000, dtype: object


In [210]:
lr = 5e-4
batch_size = 32
log_interval = 50
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

numbers_loader_train = DataLoader(
    trainData,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collater,
)

print(numbers_loader_train.dataset)

src_dict = trainData
tgt_dict = trainLabels

transformer_encoder_cfg = {
    "num_layers": 3,
    "embed_dim": 256,
    "ffn_dim": 1024,
    "num_heads": 4,
    #"vocab_size": len(src_dict),
    "dropout": 0.1,
}
transformer_decoder_cfg = {
    "num_layers": 3,
    "embed_dim": 256,
    "ffn_dim": 1024,
    "num_heads": 4,
    # "vocab_size": len(tgt_dict),
    "dropout": 0.1,
}
model = Transformer(transformer_encoder_cfg, transformer_decoder_cfg)
model.to(device)
model.train()

optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = F.nll_loss

print("Training model...")

loss_avg = 0
print(numbers_loader_train)
for i, (src, tgt) in enumerate(numbers_loader_train):
    print(src)
    src = {k: v.to(device) for k, v in src.items()}
    print(src)
    tgt = {k: v.to(device) for k, v in tgt.items()}

    optimizer.zero_grad()

    output = model(
        src['ids'],
        tgt['ids'][:-1],
        src['padding_mask'],
        tgt['padding_mask'][:, :-1],
    )

    loss = criterion(
        output.reshape(-1, output.size(-1)),
        tgt['ids'][1:].flatten()
    )
    loss.backward()
    optimizer.step()

    loss_avg += loss.item()
    if (i+1) % log_interval == 0:
        loss_avg /= log_interval
        print(f"{i+1}/{len(numbers_loader_train)}\tLoss: {loss_avg}")

        0         1         2         3         4         5         6    \
1       0.0 -0.002737 -0.003256 -0.002842 -0.003326 -0.003696 -0.002624   
2       1.0 -0.002686 -0.003358 -0.004155 -0.005550 -0.006590 -0.007223   
3       2.0 -0.002638 -0.002471 -0.002312 -0.002172 -0.002040 -0.002214   
4       3.0 -0.001875 -0.002034 -0.002197 -0.002201 -0.002347 -0.002576   
5       4.0 -0.006637 -0.006698 -0.007560 -0.007685 -0.008237 -0.007881   
...     ...       ...       ...       ...       ...       ...       ...   
996   995.0  0.005960  0.005648  0.005329  0.004734  0.004226  0.003602   
997   996.0  0.003330  0.003050  0.004500  0.005088  0.005770  0.007044   
998   997.0 -0.000866 -0.002570 -0.004254 -0.005894 -0.007288 -0.008403   
999   998.0  0.002161  0.003685  0.005081  0.006265  0.007204  0.007833   
1000  999.0  0.004353  0.004604  0.004521  0.004428  0.004170  0.004015   

           7         8         9    ...       247       248       249  \
1    -0.002620 -0.001829 -

KeyError: 535