In [9]:
import torch
from matplotlib import pyplot as plt
import numpy as np
import time
from torch import nn
def in_context_loss(model, Z, y):
    N = Z.shape[1]-1
    d = Z.shape[2]-1
    output = model(Z)
    diff = output[:,N,d]+y
    loss = ((diff)**2).mean()
    return loss
def attention(P,Q,Z, activation = None):
    B= Z.shape[0]
    N = Z.shape[1]-1
    d = Z.shape[2]-1
    P_full =  torch.cat([P,torch.zeros(1,d).to(device)],dim=0)
    P_full =  torch.cat([P_full,torch.zeros(d+1,1).to(device)],dim=1)
    P_full = P_full.clone()
    P_full[d, d] = 1
    Q_full = torch.cat([Q, torch.zeros(1,d).to(device)],dim=0)
    Q_full = torch.cat([Q_full, torch.zeros(d+1,1).to(device)],dim=1)
    A = torch.eye(N+1).to(device)
    A[N,N] = 0
    Attn = torch.einsum('BNi, ij, BMj -> BNM', (Z,Q_full,Z))
    if activation is not None:
        Attn = activation(Attn)
    key = torch.einsum('ij, BNj -> BNi', (P_full,Z))
    Output = torch.einsum('BNM,ML, BLi -> BNi', (Attn,A,key))
    return Output /N
def generate_data_sine(N=10, B=1000, d=1):
    # Sample amplitude a and phase ρ for each task
    a = torch.FloatTensor(B, 1).uniform_(1, 2).cuda()
    rho = torch.FloatTensor(B, 1).uniform_(np.pi/2, np.pi).cuda()
    # Sample inputs x uniformly between -5 and 5
    X = torch.FloatTensor(B, N, d).uniform_(-5, 5).cuda()
    X_test = torch.FloatTensor(B, 1, d).uniform_(-5, 5).cuda()
    y = a.unsqueeze(1) * torch.sin(rho.unsqueeze(1) + X)
    y_test = a.unsqueeze(1) * torch.sin(rho.unsqueeze(1) + X_test)
    # Combine X and y for the full dataset
    X_comb = torch.cat([X, X_test], dim=1)
    y_comb = torch.cat([y, torch.zeros(B, 1, 1).cuda()], dim=1)
    Z = torch.cat([X_comb, y_comb], dim=2)
    return Z.to('cuda'), y_test.squeeze(1).to('cuda')
class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(MLP, self).__init__()
        self.embed = nn.Sequential(
            nn.Linear(input_dim, output_dim-1))
        self.mlp2 = nn.Sequential(
            nn.Linear(output_dim-1, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim-1)
        )
    def forward(self, Z):
        X = Z[:,:,0].unsqueeze(2)
        out = self.embed(X) + self.mlp2(self.embed(X))
        out = torch.cat([out,Z[:,:,1].unsqueeze(2)],dim=2)
        return out
        #return self.layers(x)
class Transformer_MLP(nn.Module):
    def __init__(self, n_layer, n_head, d, var):
        super(Transformer_MLP, self).__init__()
        # Ensure the MLP is initialized with the correct dimensions
        hidden_size=160
        output_dim=40
        self.mlp = MLP(input_dim=1, hidden_dim=hidden_size, output_dim=output_dim)  # Adjust hidden_dim as needed
        self.n_layer = n_layer
        self.n_head = n_head
        self.register_parameter('allparam', nn.Parameter(torch.zeros(n_layer, n_head, 2, output_dim-1, output_dim-1)))
        with torch.no_grad():
            self.allparam.normal_(0, var)
    def forward(self, Z):
        B, N, _ = Z.shape
        # Apply MLP to each token
        #for i in range(N):
        #Z_new = torch.zeros([B,N,20]).cuda()#Z.clone()
        #Z_new[:, i, :] = self.mlp(Z[:, i, :]) #+ Z[:, i, :]
        Z = self.mlp(Z)
        #Z = Z_new
        #print(Z.shape)
        # Apply self-attention layers
        for i in range(self.n_layer):
            Zi = Z
            residues = 0
            for j in range(self.n_head):
                Pij = self.allparam[i, j, 0, :, :]
                Qij = self.allparam[i, j, 1, :, :]
                residues = residues + attention(Pij, Qij, Zi)
            Z = Zi + residues
        return Z
# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_layer = 3  # number of layers of transformer
d = 1        # dimension for sine data
n_head = 1   # 1-headed attention
B = 1000     # minibatch size for sine data
var = 0.0001 # initializations scale of transformer parameter
max_iters = 10000  # Number of Iterations to run
learning_rates = [0.0001, 0.0005, 0.001, 0.002]
# Function to run training
def train_transformer(optimizer_name, learning_rate, max_epochs):
    # Initialize model and optimizer
    model = Transformer_MLP(n_layer, n_head, d, var).to(device)
    if optimizer_name == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
    elif optimizer_name == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.9))
    losses = []
    for epoch in range(max_epochs):
        # Generate data
        Z, y = generate_data_sine(N=100, B=B)
        Z, y = Z.to(device), y.to(device)
        # Training step
        optimizer.zero_grad()
        loss = in_context_loss(model, Z, y)
        loss.backward()
        optimizer.step()
        # Record loss
        losses.append(loss.item())
        # Optionally print progress
        if epoch % 100 == 0:
            print(f'Epoch {epoch}, Loss: {loss.item()}')
    return losses
# Tuning learning rate
tuning_epochs = 200
optimal_lr = {}
for optimizer_name in ['adam']:
    best_lr = None
    best_loss = float('inf')
    for lr in learning_rates:
        print(f"Training with {optimizer_name} optimizer, Learning Rate: {lr}")
        losses = train_transformer(optimizer_name, lr, tuning_epochs)
        avg_loss = np.mean(losses[-20:])  # Average loss over last epochs
        if avg_loss < best_loss:
            best_loss = avg_loss
            best_lr = lr
    optimal_lr[optimizer_name] = best_lr
    print(f"Optimal Learning Rate for {optimizer_name}: {best_lr}")
# Extended training with optimal learning rates
final_losses = {}
for optimizer_name, lr in optimal_lr.items():
    print(f"Extended Training with {optimizer_name} optimizer, Learning Rate: {lr}")
    losses = train_transformer(optimizer_name, lr, max_iters)
    final_losses[optimizer_name] = losses
# Plotting losses
plt.figure(figsize=(10, 6))
for optimizer_name, losses in final_losses.items():
    plt.plot(losses, label=f'{optimizer_name} LR={optimal_lr[optimizer_name]}')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.yscale('log')
plt.legend()
plt.show()

Training with sgd optimizer, Learning Rate: 0.001
Epoch 0, Loss: 1.6699517965316772
Epoch 100, Loss: 1.185174584388733
Training with sgd optimizer, Learning Rate: 0.005
Epoch 0, Loss: 1.2234455347061157
Epoch 100, Loss: 1.1859381198883057
Training with sgd optimizer, Learning Rate: 0.01
Epoch 0, Loss: 1.57686448097229
Epoch 100, Loss: 1.1376088857650757
Training with sgd optimizer, Learning Rate: 0.02
Epoch 0, Loss: 1.2858954668045044
Epoch 100, Loss: 1.163874864578247
Optimal Learning Rate for sgd: 0.005
Training with adam optimizer, Learning Rate: 0.001
Epoch 0, Loss: 1.2446575164794922
Epoch 100, Loss: 1.1583644151687622
Training with adam optimizer, Learning Rate: 0.005
Epoch 0, Loss: 1.3413984775543213
Epoch 100, Loss: 1.0970300436019897
Training with adam optimizer, Learning Rate: 0.01
Epoch 0, Loss: 1.4299474954605103
Epoch 100, Loss: 1.1471378803253174
Training with adam optimizer, Learning Rate: 0.02
Epoch 0, Loss: 1.5571529865264893
Epoch 100, Loss: 1.1547386646270752
Optimal

KeyboardInterrupt: 