In [1]:
import torch
import torch.nn.functional as F

In [2]:
# -------------------------------
# Tensor Equality Check Example
# -------------------------------
# Two float tensors built from the same literal values.
T1, T2 = torch.tensor([1., 2.]), torch.tensor([1., 2.])

# `torch.allclose` compares element-wise within a tolerance rather than
# demanding exact equality, which is the safer check for floating-point data.
are_equal = torch.allclose(T1, T2)
print("Tensors are equal:", are_equal)

Tensors are equal: True


In [3]:
# -------------------------------
# Lower Triangular Matrix Example
# -------------------------------
# `tril` keeps the main diagonal and everything below it and zeroes the rest.
# Applied to an all-ones 3x3 matrix, this gives the "staircase" pattern used
# for causal masks.
matrix = torch.ones(3, 3).tril()
print("Lower triangular matrix:\n", matrix)

Lower triangular matrix:
 tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])


In [6]:
# -------------------------------
# Masked Fill Example (Attention Mask)
# -------------------------------
# A causal mask lets row i attend only to columns 0..i.
# `F` (torch.nn.functional) is imported in the top-of-notebook import cell.

T = 5  # context length
tril = torch.tril(torch.ones(T, T))  # 1 = allowed (on/below diagonal), 0 = blocked

# Start from all-zero (uniform) scores, then set the blocked positions to -inf
# so that softmax gives them exactly zero weight (exp(-inf) == 0).
wei = torch.zeros(T, T)
wei = wei.masked_fill(tril == 0, float('-inf'))
print("Masked weights (before softmax):\n", wei)

# Row-wise softmax: each row is normalized over its unmasked entries. Every row
# keeps at least its diagonal entry, so no row is all -inf (which would give NaN).
wei = F.softmax(wei, dim=-1)
print("Weights after softmax:\n", wei)

Masked weights (before softmax):
 tensor([[0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0.]])
Weights after softmax:
 tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]])


In [7]:
# -------------------------------
# Register Buffer Example
# -------------------------------

class MyModule(torch.nn.Module):
    """Toy module that stores a lower-triangular (causal) mask as a buffer.

    A tensor registered with ``register_buffer`` is saved in the module's
    ``state_dict`` and moved by ``.to(...)`` together with the model. It is not
    a parameter, so an optimizer never trains it.

    Args:
        size: side length of the square mask. Defaults to 4.
    """

    def __init__(self, size: int = 4):
        super().__init__()
        # buffer is saved and moved with the model but not trained
        self.register_buffer("mask", torch.tril(torch.ones(size, size)))

    def forward(self, x):
        """Print the buffer to show it is reachable inside forward, then return ``x`` unchanged."""
        print("Inside forward, buffer mask:\n", self.mask)
        return x

m = MyModule()
print("Registered buffer:\n", m.mask)

# `.to()` moves buffers as well as parameters and returns the module itself.
# The module is created on the default (CPU) device here, so this call only
# demonstrates the API. Pass "cuda" on a GPU machine to see the device change.
m = m.to("cpu")
print("Buffer on CPU:\n", m.mask.device)

# Buffers are serialized alongside parameters, so the mask is listed in state_dict.
state = m.state_dict()
print("Buffer in state_dict:\n", state["mask"])


Registered buffer:
 tensor([[1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 1.]])
Buffer on CPU:
 cpu
Buffer in state_dict:
 tensor([[1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 1.]])
