# Example Without Buffers

In [1]:
import torch,torch.nn as nn
class CausalAttentionWithoutBuffers(nn.Module):
    """Single-head causal self-attention whose mask is a plain tensor attribute.

    The mask is assigned directly to ``self.mask`` instead of being registered
    as a buffer, so it is not a parameter or buffer of the module: it is absent
    from ``state_dict()`` and is not relocated by ``Module.to()``.

    Args:
        d_in: size of the last dimension of the input.
        d_out: size of the query/key/value projections (and of the output).
        context_length: side length of the square causal mask, i.e. the
            longest sequence the mask can cover.
        dropout: dropout probability applied to the attention weights.
        qkv_bias: whether the three projection layers carry a bias term.
    """
    def __init__(self,d_in,d_out,context_length,dropout,qkv_bias=False):
        super().__init__()
        self.d_out=d_out
        # Layers are created in query/key/value order so seeded
        # initialisation consumes the random stream in a fixed order.
        self.W_query=nn.Linear(d_in,d_out,bias=qkv_bias)
        self.W_key=nn.Linear(d_in,d_out,bias=qkv_bias)
        self.W_value=nn.Linear(d_in,d_out,bias=qkv_bias)
        self.dropout=nn.Dropout(dropout)
        # Ones strictly above the diagonal mark the "future" positions.
        all_ones=torch.ones(context_length,context_length)
        self.mask=torch.triu(all_ones,diagonal=1)
    def forward(self,x):
        """Return context vectors of shape (batch, num_tokens, d_out).

        ``x`` must be 3-dimensional: (batch, num_tokens, d_in).
        """
        _batch,seq_len,_width=x.shape
        q=self.W_query(x)
        k=self.W_key(x)
        v=self.W_value(x)
        scores=q@k.transpose(1,2)
        # Crop the mask to the current sequence length; any non-zero entry
        # counts as masked once converted to bool.
        future=self.mask.bool()[:seq_len,:seq_len]
        scores.masked_fill_(future,-torch.inf)
        # Scale by sqrt of the key dimension before normalising each row.
        scale=k.shape[-1]**.5
        weights=self.dropout(torch.softmax(scores/scale,dim=-1))
        return weights@v
# Six tokens, each a 3-dimensional embedding (one row per token).
inputs=torch.tensor([
    [.43,.15,.89],
    [.55,.87,.66],
    [.57,.85,.64],
    [.22,.58,.33],
    [.77,.25,.1],
    [.05,.8,.55],
])
# Stack the same sequence twice to get a batch of shape (2, 6, 3).
batch=torch.stack([inputs,inputs])
# Rows give the sequence length, columns the embedding width.
context_length,d_in=inputs.shape
d_out=2
ca_without_buffer=CausalAttentionWithoutBuffers(d_in,d_out,context_length,dropout=0)
# Inference only: no gradients are needed for this demonstration.
with torch.no_grad():
    context_vecs=ca_without_buffer(batch)
context_vecs

tensor([[[-0.1326, -0.6088],
         [-0.1689, -0.7603],
         [-0.1853, -0.8077],
         [-0.1602, -0.7431],
         [-0.2002, -0.6679],
         [-0.1580, -0.6887]],

        [[-0.1326, -0.6088],
         [-0.1689, -0.7603],
         [-0.1853, -0.8077],
         [-0.1602, -0.7431],
         [-0.2002, -0.6679],
         [-0.1580, -0.6887]]])

In [2]:
# Ask PyTorch whether its MPS backend can be used on this machine.
has_mps=torch.backends.mps.is_available()
print('Machine has GPU: {}'.format(has_mps))

Machine has GPU: True


In [3]:
# Pick the best available device instead of hard-coding 'mps': prefer the MPS
# backend, then CUDA, and fall back to the CPU. Every later `.to(device)` call
# then works on any machine, so the notebook survives Restart & Run All.
if torch.backends.mps.is_available():
    device=torch.device('mps')
elif torch.cuda.is_available():
    device=torch.device('cuda')
else:
    device=torch.device('cpu')
print(f'Using device: {device}')

Using device: mps


In [4]:
# Move the input batch and the module to the selected device, then report
# where the query weights and the plain-attribute mask ended up.
batch=batch.to(device)
ca_without_buffer=ca_without_buffer.to(device)
print(f'W_query.device: {ca_without_buffer.W_query.weight.device}',
      f'mask.device: {ca_without_buffer.mask.device}',
      type(ca_without_buffer.mask),
      sep='\n')

W_query.device: mps:0
mask.device: cpu
<class 'torch.Tensor'>


In [5]:
# The mask is an ordinary attribute, so it has to be moved to the device by
# hand and re-assigned on the module.
ca_without_buffer.mask=ca_without_buffer.mask.to(device)
print('mask.device:',ca_without_buffer.mask.device)

mask.device: mps:0


In [6]:
# Re-run the forward pass now that the batch, the weights and the mask all
# live on the same device; gradients are not needed for this check.
with torch.no_grad():
    context_vecs=ca_without_buffer(batch)
context_vecs

tensor([[[-0.1326, -0.6088],
         [-0.1689, -0.7603],
         [-0.1853, -0.8077],
         [-0.1602, -0.7431],
         [-0.2002, -0.6679],
         [-0.1580, -0.6887]],

        [[-0.1326, -0.6088],
         [-0.1689, -0.7603],
         [-0.1853, -0.8077],
         [-0.1602, -0.7431],
         [-0.2002, -0.6679],
         [-0.1580, -0.6887]]], device='mps:0')

# Example With Buffers
## Buffers, state_dict

In [7]:
class CausalAttentionWithBuffer(nn.Module):
    """Single-head causal self-attention whose mask is a registered buffer.

    Registering the mask with ``register_buffer`` makes it part of the
    module's state: it appears in ``state_dict()`` under the key ``'mask'``
    and follows the module when it is moved with ``Module.to()``.

    Args:
        d_in: size of the last dimension of the input.
        d_out: size of the query/key/value projections (and of the output).
        context_length: side length of the square causal mask, i.e. the
            longest sequence the mask can cover.
        dropout: dropout probability applied to the attention weights.
        qkv_bias: whether the three projection layers carry a bias term.
    """
    def __init__(self,d_in,d_out,context_length,dropout,qkv_bias=False):
        super().__init__()
        self.d_out=d_out
        # Query, key and value projections, created in this fixed order.
        self.W_query=nn.Linear(d_in,d_out,bias=qkv_bias)
        self.W_key=nn.Linear(d_in,d_out,bias=qkv_bias)
        self.W_value=nn.Linear(d_in,d_out,bias=qkv_bias)
        self.dropout=nn.Dropout(dropout)
        # Strictly-upper-triangular ones flag the positions a token may not see.
        causal_mask=torch.triu(torch.ones(context_length,context_length),diagonal=1)
        self.register_buffer('mask',causal_mask)
    def forward(self,x):
        """Return context vectors of shape (batch, num_tokens, d_out).

        ``x`` must be 3-dimensional: (batch, num_tokens, d_in).
        """
        _batch,seq_len,_width=x.shape
        k=self.W_key(x)
        q=self.W_query(x)
        v=self.W_value(x)
        scores=q@k.transpose(1,2)
        # Crop the buffer to the actual sequence length; non-zero means masked.
        blocked=self.mask.bool()[:seq_len,:seq_len]
        scores.masked_fill_(blocked,-torch.inf)
        # Scale by sqrt of the key dimension, then normalise over the keys.
        probs=torch.softmax(scores/k.shape[-1]**.5,dim=-1)
        probs=self.dropout(probs)
        return probs@v
# Build the buffer-based module and move it to the device in one step;
# Module.to() returns the module itself.
ca_with_buffer=CausalAttentionWithBuffer(d_in,d_out,context_length,0).to(device)
# For comparison, display the state dict of the model WITHOUT a buffer.
ca_without_buffer.state_dict()

OrderedDict([('W_query.weight',
              tensor([[-0.0850, -0.0829, -0.4284],
                      [ 0.3573, -0.4932, -0.4831]], device='mps:0')),
             ('W_key.weight',
              tensor([[ 0.4202,  0.2805, -0.5325],
                      [-0.2468, -0.3244, -0.5748]], device='mps:0')),
             ('W_value.weight',
              tensor([[-0.4955,  0.0073,  0.0892],
                      [-0.2776, -0.5334, -0.4600]], device='mps:0'))])

In [8]:
# Display the state dict of the buffer-based model to see which entries it holds.
ca_with_buffer.state_dict()

OrderedDict([('mask',
              tensor([[0., 1., 1., 1., 1., 1.],
                      [0., 0., 1., 1., 1., 1.],
                      [0., 0., 0., 1., 1., 1.],
                      [0., 0., 0., 0., 1., 1.],
                      [0., 0., 0., 0., 0., 1.],
                      [0., 0., 0., 0., 0., 0.]], device='mps:0')),
             ('W_query.weight',
              tensor([[-0.1735, -0.2387, -0.0884],
                      [ 0.1110, -0.1784, -0.2616]], device='mps:0')),
             ('W_key.weight',
              tensor([[-0.0903,  0.4589,  0.3723],
                      [ 0.5692, -0.0475, -0.1776]], device='mps:0')),
             ('W_value.weight',
              tensor([[-0.0305, -0.2526,  0.1203],
                      [ 0.0751,  0.4521,  0.3481]], device='mps:0'))])

In [9]:
# Overwrite every 1 in the registered mask with 2, in place, so the change
# can be recognised after saving and reloading the state dict.
ca_with_buffer.mask.masked_fill_(ca_with_buffer.mask==1,2)
ca_with_buffer.mask

tensor([[0., 2., 2., 2., 2., 2.],
        [0., 0., 2., 2., 2., 2.],
        [0., 0., 0., 2., 2., 2.],
        [0., 0., 0., 0., 2., 2.],
        [0., 0., 0., 0., 0., 2.],
        [0., 0., 0., 0., 0., 0.]], device='mps:0')

In [10]:
# Round-trip the buffer-based model through a checkpoint file and inspect the
# mask of a freshly constructed module after loading.
torch.save(ca_with_buffer.state_dict(),'model.pth')
new_ca_with_buffer=CausalAttentionWithBuffer(d_in,d_out,context_length,0)
# map_location='cpu' deserialises the tensors onto the CPU, where the new
# module lives, instead of onto the device they were saved from.
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state dict contains.
state=torch.load('model.pth',map_location='cpu',weights_only=True)
new_ca_with_buffer.load_state_dict(state)
new_ca_with_buffer.mask

tensor([[0., 2., 2., 2., 2., 2.],
        [0., 0., 2., 2., 2., 2.],
        [0., 0., 0., 2., 2., 2.],
        [0., 0., 0., 0., 2., 2.],
        [0., 0., 0., 0., 0., 2.],
        [0., 0., 0., 0., 0., 0.]])

In [11]:
# Repeat the experiment with the model whose mask is a plain attribute:
# modify the mask, save the state dict, and load it into a fresh module.
ca_without_buffer.mask[ca_without_buffer.mask==1]=2
torch.save(ca_without_buffer.state_dict(),'model.pth')
new_ca_without_buffer=CausalAttentionWithoutBuffers(d_in,d_out,context_length,0)
# map_location='cpu' deserialises the tensors onto the CPU, where the new
# module lives, instead of onto the device they were saved from.
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state dict contains.
state=torch.load('model.pth',map_location='cpu',weights_only=True)
new_ca_without_buffer.load_state_dict(state)
new_ca_without_buffer.mask

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])