#Atención desnuda.

El mecanismo de la atención paso a paso

```
# Tiene formato de código
```



Copia de: https://github.com/jostmey/NakedAttention/tree/8a808e1344989a00082b33b3a7ab38410b599747

In [5]:
!pip install torchmetrics

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchmetrics
  Downloading torchmetrics-0.11.4-py3-none-any.whl (519 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.2/519.2 kB[0m [31m8.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torchmetrics
Successfully installed torchmetrics-0.11.4


In [6]:
import torchvision
import torch
import torchmetrics

Bloque de dataplumbing:

In [7]:

##########################################################################################
# Carga datos del MNIST
##########################################################################################

# Load training, validation, and test data from the MNIST dataset
#
def load_mnist(seed=None, device=torch.device('cpu')):

  # Random number generator
  # 
  generator = torch.Generator(device=device)
  if seed is not None:
    generator.manual_seed(seed)

  # Load MNIST dataset
  #
  samples_train = torchvision.datasets.MNIST('./', train=True, download=True)
  samples_test = torchvision.datasets.MNIST('./', train=False, download=True)

  # Format features and labels
  #
  xs = samples_train.data.to(device)
  num = xs.shape[0]
  xs = xs.reshape([ num, 28**2, 1 ])
  xs = xs.type(torch.float32)
  ys = samples_train.train_labels.to(device)

  xs_test = samples_test.data.to(device)
  num_test = xs_test.shape[0]
  xs_test = xs_test.reshape([ num_test, 28**2, 1 ])
  xs_test = xs_test.type(torch.float32)
  ys_test = samples_test.test_labels.to(device)

  # Split into training and validation samples
  #
  num_train = int(num*5/6)
  num_val = num-num_train

  js = torch.randperm(num, generator=generator)
  js_train = js[:num_train]
  js_val = js[num_train:]

  xs_train = xs[js_train]
  ys_train = ys[js_train]

  xs_val = xs[js_val]
  ys_val = ys[js_val]

  # Normalizing features
  #
  mean = torch.mean(xs_train, axis=0, keepdim=True)
  variance = torch.var(xs_train, axis=0, keepdim=True)

  xs_train = (xs_train-mean)/torch.std(variance+1.0E-8)
  xs_val = (xs_val-mean)/torch.std(variance+1.0E-8)
  xs_test = (xs_test-mean)/torch.std(variance+1.0E-8)

  return xs_train, ys_train, xs_val, ys_val, xs_test, ys_test

In [None]:
##########################################################################################
# Model
##########################################################################################

class SelfAttentionModel(torch.nn.Module):
  def __init__(self, num_steps, num_channels, num_outputs, **kwargs):
    super().__init__(**kwargs)

    # Initialize components for self-attention
    #
    self.K = torch.nn.Parameter((2.0*torch.rand(num_channels, num_channels)-1.0)/num_channels**0.5) # Randomly intialize each weight uniformly from [ -1/num_channels**0.5, 1/num_channels**0.5 ]
    self.Q = torch.nn.Parameter((2.0*torch.rand(num_channels, num_channels)-1.0)/num_channels**0.5)
    self.V = torch.nn.Parameter((2.0*torch.rand(num_channels, num_channels)-1.0)/num_channels**0.5)

    self.softmax = torch.nn.Softmax(dim=1)

    # Initialize output layer
    #
    self.out = torch.nn.Linear(num_steps*num_channels, num_outputs)

  def forward(self, x):

    batch_size, num_steps, num_channels = x.shape

    # Run self attention
    #
    y = []
    for i in range(batch_size): # Process one sample at a time

      x_i = x[i,:,:] # x_i has shape of [ num_steps, num_channels ]

      x_k_i = torch.matmul(x_i, self.K) # x_k_i has shape of [ num_steps, num_channels ]
      x_q_i = torch.matmul(x_i, self.Q) # x_q_i has shape of [ num_steps, num_channels ]
      x_v_i = torch.matmul(x_i, self.V) # x_v_i has shape of [ num_steps, num_channels ]

      w_i = self.softmax(torch.matmul(x_q_i, x_k_i.T)/num_channels**0.5) # w_i has shape of [ num_steps, num_steps ]
      y_i = torch.matmul(w_i, x_v_i) # y_i has shape of [ num_steps, num_channels ]

      y.append(y_i)
    y = torch.stack(y, axis=0) # y has shape of [ batch_size, num_steps, num_channels ]

    # Flatten output
    #
    y_flat = y.reshape([ batch_size, num_steps*num_channels ]) # y_flat has shape of [ batch_size, num_steps*num_channels ]

    # Run output layer
    #
    l = self.out(y_flat) # l has shape of [ batch_size, num_outputs ]

    return l

##########################################################################################
# Instantiate model, performance metrics, and optimizer.
##########################################################################################

model = SelfAttentionModel(num_steps=28**2, num_channels=1, num_outputs=10)
probability = torch.nn.Softmax(dim=1)

loss = torch.nn.CrossEntropyLoss()
accuracy = torchmetrics.classification.MulticlassAccuracy(num_classes=10)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

##########################################################################################
# Dataset and data sampler
##########################################################################################

xs_train, ys_train, xs_val, ys_val, xs_test, ys_test = load_mnist(seed=46525)

dataset_train = torch.utils.data.TensorDataset(xs_train, ys_train)
sampler_train = torch.utils.data.RandomSampler(dataset_train, replacement=True)
loader_train = torch.utils.data.DataLoader(dataset=dataset_train, batch_size=16, sampler=sampler_train, drop_last=True)

##########################################################################################
# Model
##########################################################################################

i_better = -1
e_better = 1.0e8
a_better = 0.0
state_better = {}

# Loop over the dataset for many epochs
#
for i in range(128):

  # Train the model
  #
  model.train()
  e_train = 0.0
  a_train = 0.0
  for xs_batch, ys_batch in iter(loader_train): # Must use `iter` or `enumerate` for efficiency
    ls_batch = model(xs_batch)
    ps_batch = probability(ls_batch) # Model outputs logits that we must convert to probabilities
    e_batch = loss(ls_batch, ys_batch) # CrossEntropyLoss requires logits
    a_batch = accuracy(ps_batch, ys_batch)
    optimizer.zero_grad()
    e_batch.backward()
    optimizer.step()
    e_train += e_batch.detach()/len(loader_train) # Accumulate average loss for this epoch
    a_train += a_batch.detach()/len(loader_train) # Accumulate average accuracy for this epoch

  # Assess performance on validation data
  #
  model.eval()
  with torch.no_grad():
    ls_val = model(xs_val)
    ps_val = probability(ls_val) # Model outputs logits that we must convert to probabilities
    e_val = loss(ls_val, ys_val) # CrossEntropyLoss requires logits
    a_val = accuracy(ps_val, ys_val)
    if e_val < e_better: # Early stopping check
      i_better = i
      e_better = e_val
      a_better = a_val
      state_better = model.state_dict()

  # Print report
  #
  print(
    'i: '+str(i),
    'e_train: {:.5f}'.format(float(e_train)/0.693)+' bits',
    'a_train: {:.1f}'.format(100.0*float(a_train))+' %',
    'e_val: {:.5f}'.format(float(e_val)/0.693)+' bits',
    'a_val: {:.1f}'.format(100.0*float(a_val))+' %',
    sep='\t', flush=True
  )

model.eval()
model.load_state_dict(state_better)
with torch.no_grad():
  ls_test = model(xs_test)
  ps_test = probability(ls_test) # Model outputs logits that we must convert to probabilities
  e_test = loss(ls_test, ys_test) # CrossEntropyLoss requires logits
  a_test = accuracy(ps_test, ys_test)

print(
  'e_test: {:.5f}'.format(float(e_test)/0.693)+' bits',
  'a_test: {:.1f}'.format(100.0*float(a_test))+' %',
  sep='\t', flush=True
)



Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 102760359.88it/s]


Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 17209219.18it/s]


Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 26323389.20it/s]


Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 13039376.30it/s]


Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw





i: 0	e_train: 1.89826 bits	a_train: 44.2 %	e_val: 0.56164 bits	a_val: 88.9 %
i: 1	e_train: 0.51991 bits	a_train: 73.6 %	e_val: 0.52756 bits	a_val: 89.5 %
i: 2	e_train: 0.47493 bits	a_train: 74.2 %	e_val: 0.49232 bits	a_val: 90.5 %
i: 3	e_train: 0.44804 bits	a_train: 74.9 %	e_val: 0.46489 bits	a_val: 91.3 %
i: 4	e_train: 0.43954 bits	a_train: 74.9 %	e_val: 0.46276 bits	a_val: 91.6 %
i: 5	e_train: 0.42702 bits	a_train: 75.1 %	e_val: 0.46872 bits	a_val: 91.2 %
