# batch & time & channel
In the context of transformer models, the terms "batch," "time," and "channel" refer to different dimensions of the input data. Here's a breakdown of each term:
- **Batch**: The number of sentences (or sequences) processed at the same time.
- **Time**: The number of tokens in each sentence (or the length of the sequence).
- **Channel**: The size of the embedding for each token.

So, an input tensor to a transformer model might have the shape `[batch_size, sequence_length, embedding_size]`. 
For example, if you have a batch of **32 sentences**, each with **50 tokens**, and **each token is represented by a 512-dimensional embedding**, the input tensor shape would be `[32, 50, 512]`.

In [1]:
# Lyrics of "Never Gonna Give You Up"
lyrics = """Never gonna give you up
Never gonna let you down
Never gonna run around and desert you
Never gonna make you cry
Never gonna say goodbye
Never gonna tell a lie and hurt you"""

# Tokenizing the lyrics
tokens = lyrics.split()
print("Tokens:\n", tokens)

# Create a mapping from tokens to integers
token_to_int = {token: idx for idx, token in enumerate(set(tokens), 1)}
int_to_token = {idx: token for token, idx in token_to_int.items()}
int_to_token[0] = "<pad>"
print("Token to Integer Mapping:\n", token_to_int)
print("Integer to Token Mapping:\n", int_to_token)

# Convert tokens to integers
encoded_lyrics = [token_to_int[token] for token in tokens]

# Display the token to integer mapping and encoded lyrics
print("Encoded Lyrics:\n", encoded_lyrics)


Tokens:
 ['Never', 'gonna', 'give', 'you', 'up', 'Never', 'gonna', 'let', 'you', 'down', 'Never', 'gonna', 'run', 'around', 'and', 'desert', 'you', 'Never', 'gonna', 'make', 'you', 'cry', 'Never', 'gonna', 'say', 'goodbye', 'Never', 'gonna', 'tell', 'a', 'lie', 'and', 'hurt', 'you']
Token to Integer Mapping:
 {'run': 1, 'and': 2, 'goodbye': 3, 'up': 4, 'let': 5, 'tell': 6, 'a': 7, 'around': 8, 'gonna': 9, 'cry': 10, 'hurt': 11, 'lie': 12, 'desert': 13, 'say': 14, 'down': 15, 'make': 16, 'give': 17, 'you': 18, 'Never': 19}
Integer to Token Mapping:
 {1: 'run', 2: 'and', 3: 'goodbye', 4: 'up', 5: 'let', 6: 'tell', 7: 'a', 8: 'around', 9: 'gonna', 10: 'cry', 11: 'hurt', 12: 'lie', 13: 'desert', 14: 'say', 15: 'down', 16: 'make', 17: 'give', 18: 'you', 19: 'Never', 0: '<pad>'}
Encoded Lyrics:
 [19, 9, 17, 18, 4, 19, 9, 5, 18, 15, 19, 9, 1, 8, 2, 13, 18, 19, 9, 16, 18, 10, 19, 9, 14, 3, 19, 9, 6, 7, 12, 2, 11, 18]


In [2]:
import torch
import torch.nn as nn

In [3]:
# Define a simple transformer model
class SimpleTransformer(nn.Module):
    def __init__(self, embedding_size, num_heads, num_layers, vocab_size):
        super(SimpleTransformer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.transformer = nn.Transformer(
            d_model=embedding_size,
            nhead=num_heads,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            batch_first=True
        )
        self.linear = nn.Linear(embedding_size, vocab_size)
    
    def forward(self, src, tgt):
        src_emb = self.embedding(src)
        tgt_emb = self.embedding(tgt)
        transformer_output = self.transformer(src_emb, tgt_emb)
        output = self.linear(transformer_output)
        return output

    def generate(self, src, max_length):
        src_emb = self.embedding(src)
        memory = self.transformer.encoder(src_emb)
        ys = torch.zeros(1, 1).type_as(src.data)

        for i in range(max_length-1):
            tgt_emb = self.embedding(ys)
            out = self.transformer.decoder(tgt_emb, memory)
            out = self.linear(out)
            prob = out[:, -1, :].squeeze().softmax(dim=-1)
            next_word = torch.argmax(prob).item()
            ys = torch.cat([ys, torch.tensor([[next_word]]).type_as(src.data)], dim=1)
            if next_word == token_to_int.get('<eos>', 0):  # Assuming <eos> token for end of sentence
                break
        return ys


In [4]:
# Parameters
batch_size = 1  # For simplicity, using a batch size of 1
sequence_length = len(encoded_lyrics)
embedding_size = 512

# Simulate embedding by creating a tensor with random values
# Normally, you would use actual embeddings here
input_data = (torch.tensor(encoded_lyrics) # 1 x 34
    .view(batch_size, sequence_length, 1) # 1 x 34 x 1
    .float()
    .repeat(1, 1, embedding_size)) # 1 x 34 x 512
print("Input data shape:", input_data.shape)
print("Input data:\n", input_data)

# Permute to match transformer input shape [sequence_length, batch_size, embedding_size]
# permute() returns a view of the input tensor with dimensions permuted(變更/交換).
src = input_data.permute(1, 0, 2) # 1 x 34 x 512 -> 34 x 1 x 512
tgt = input_data.permute(1, 0, 2)

# Print shapes to verify
print("Source shape:", src.shape)
print("Target shape:", tgt.shape)


Input data shape: torch.Size([1, 34, 512])
Input data:
 tensor([[[19., 19., 19.,  ..., 19., 19., 19.],
         [ 9.,  9.,  9.,  ...,  9.,  9.,  9.],
         [17., 17., 17.,  ..., 17., 17., 17.],
         ...,
         [ 2.,  2.,  2.,  ...,  2.,  2.,  2.],
         [11., 11., 11.,  ..., 11., 11., 11.],
         [18., 18., 18.,  ..., 18., 18., 18.]]])
Source shape: torch.Size([34, 1, 512])
Target shape: torch.Size([34, 1, 512])


In [5]:
# Parameters for the transformer
num_heads = 8
num_layers = 6
vocab_size = len(token_to_int) + 1  # Plus one for padding/indexing

# Create the model
model = SimpleTransformer(embedding_size, num_heads, num_layers, vocab_size)

# Forward pass
src = torch.tensor(encoded_lyrics).unsqueeze(0)  # Add batch dimension
tgt = torch.tensor(encoded_lyrics).unsqueeze(0)  # Add batch dimension
output = model(src, tgt)

print("Output shape:", output.shape)

# Generate new sequence
start_token = torch.tensor([[token_to_int["Never"]]])
generated_sequence = model.generate(start_token, max_length=20)
str_list = list(map(lambda s: int_to_token.get(s)+" ", generated_sequence[0].tolist()))
generated_txt = "".join(str_list)
print(generated_txt)

Output shape: torch.Size([1, 34, 20])
<pad> gonna gonna lie gonna let Never let Never let gonna let gonna gonna gonna down let gonna gonna say 
