In [1]:
import torch
def generate_square_mask(dim_trg: int, dim_src: int, mask_type: str) -> torch.Tensor:
    mask = torch.ones(dim_trg, dim_trg)* float('-inf')
    if mask_type == "src":
        mask = torch.triu(mask, diagonal=1)
    elif mask_type == "tgt":
        mask = torch.triu(mask, diagonal=1)
        #mask = torch.tril(mask)
    elif mask_type == "memory":
        mask = torch.ones(dim_trg, dim_src)* float('-inf')
        mask = torch.triu(mask, diagonal=1)
    return mask


def generate_square_mask(dim_trg: int, dim_src: int, mask_type: str) -> torch.Tensor:
    """
    Generate a square mask for transformer attention mechanisms.
    
    Args:
        dim_trg (int): Target sequence length.
        dim_src (int): Source sequence length.
        mask_type (str): Type of mask to generate. Can be "src", "tgt", or "memory".
    
    Returns:
        torch.Tensor: A mask tensor with `-inf` values to block specific positions.
    """

    # Initialize a square matrix filled with -inf (default to a fully masked state)
    mask = torch.ones(dim_trg, dim_trg) * float('-inf')

    if mask_type == "src":
        # Source mask (self-attention in the encoder)
        # Creates an upper triangular matrix with -inf above the diagonal
        # This allows each token to attend to itself and previous tokens
        mask = torch.triu(mask, diagonal=1)

    elif mask_type == "tgt":
        # Target mask (self-attention in the decoder)
        # Prevents the decoder from attending to future tokens (causal mask)
        mask = torch.triu(mask, diagonal=1)

    elif mask_type == "memory":
        # Memory mask (cross-attention between encoder and decoder)
        # Controls which encoder outputs the decoder can attend to
        mask = torch.ones(dim_trg, dim_src) * float('-inf')
        mask = torch.triu(mask, diagonal=1)  # Prevents attending to future positions

    return mask


In [2]:
generate_square_mask(dim_trg = 12 ,dim_src = 8, mask_type="tgt")

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [10]:
import torch.nn as nn

dropout_layer = nn.Dropout(p=0.5)
dropout_layers = nn.Dropout(p=0.5)
dropout_layersss = nn.Dropout(p=0.3)

torch.manual_seed(42)  # Set seed
x = torch.ones(5)
print(dropout_layer(x))  # First output
print(dropout_layersss(x))
print(dropout_layers(x))


torch.manual_seed(42)  # Reset seed
print(dropout_layer(x))  # Second output (should match the first)
print(dropout_layers(x))
print(dropout_layersss(x))


tensor([0., 0., 2., 2., 2.])
tensor([1.4286, 0.0000, 1.4286, 1.4286, 1.4286])
tensor([0., 2., 2., 2., 2.])
tensor([0., 0., 2., 2., 2.])
tensor([0., 0., 0., 2., 0.])
tensor([0.0000, 1.4286, 1.4286, 1.4286, 1.4286])
