In [8]:

import math
import torch
from torch import nn
import numpy as np
from torch.utils.data import DataLoader, Dataset 
class MHAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with an output projection.

    Args:
        d_model: Feature size of the input/output tokens. Must be a positive
            multiple of ``n_heads``.
        Qbias: Whether the query and key projections have a bias term.
        Vbias: Whether the value projection has a bias term.
        Mask: Stored as ``self.Mask`` but not read by ``forward``; pass a
            ``mask`` tensor to ``forward`` instead.
        OnlyAtt: If True, ``forward`` returns ``output[:, 0]`` (see the
            ``forward`` docstring for what that selects).
        n_heads: Number of attention heads.

    Raises:
        ValueError: If ``d_model`` is not a positive multiple of ``n_heads``.
    """

    def __init__(self, d_model=4, Qbias=False, Vbias=True, Mask=False, OnlyAtt=True, n_heads=1):
        super().__init__()
        if n_heads < 1 or d_model % n_heads != 0:
            # Fail at construction time instead of inside view() in forward.
            raise ValueError(
                f"d_model ({d_model}) must be a positive multiple of n_heads ({n_heads})"
            )
        self.dk = d_model // n_heads
        # Creation order Q, K, V, O is kept so seeded initialisation is unchanged.
        self.Q = nn.Linear(d_model, d_model, bias=Qbias)
        self.K = nn.Linear(d_model, d_model, bias=Qbias)
        self.V = nn.Linear(d_model, d_model, bias=Vbias)
        self.O = nn.Linear(d_model, d_model, bias=False)
        self.Mask = Mask

        self.OnlyAtt = OnlyAtt
        self.n_heads = n_heads
        self.d_model = d_model

    def forward(self, src, mask=None):
        """Apply self-attention to ``src``.

        Args:
            src: Tensor of shape (batch, seq_len, d_model).
            mask: Optional boolean tensor; True marks a (query, key) pair whose
                score is replaced by -inf before the softmax. Accepted shapes:
                (seq_len, seq_len), (batch, seq_len, seq_len), or anything
                already broadcastable to (batch, n_heads, seq_len, seq_len).

        Returns:
            With ``OnlyAtt=False``: (batch, seq_len, d_model) when
            ``n_heads > 1``, and (batch, 1, seq_len, d_model) when
            ``n_heads == 1`` (the head axis is not merged in that case).
            With ``OnlyAtt=True``: ``output[:, 0]`` of the above, i.e. every
            token (batch, seq_len, d_model) when ``n_heads == 1`` but only the
            first token (batch, d_model) when ``n_heads > 1``.
        """
        batch_size, seq_len, _ = src.shape

        def split_heads(t):
            # (B, S, D) -> (B, H, S, D_k)
            return t.view(batch_size, seq_len, self.n_heads, self.dk).transpose(1, 2)

        Q = split_heads(self.Q(src))
        K = split_heads(self.K(src))
        V = split_heads(self.V(src))

        # Scaled dot-product attention scores: (B, H, S, S)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.dk)

        if mask is not None:
            if mask.dim() == 3:
                # A per-batch (B, S, S) mask would be right-aligned against the
                # (B, H, S, S) scores as (1, B, S, S), lining its batch axis up
                # with the head axis: the out-of-place masked_fill then expands
                # the result to (B, B, S, S) for a single head, or errors when
                # n_heads is neither 1 nor B. Insert the head axis explicitly.
                mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask, float('-inf'))

        attn = torch.softmax(scores, dim=-1)  # (B, H, S, S)
        context = torch.matmul(attn, V)  # (B, H, S, D_k)

        # NOTE(review): heads are merged back to (B, S, D) only when
        # n_heads > 1. With a single head the (B, 1, S, D_k) layout is kept, so
        # output[:, 0] below strips the head axis rather than selecting the
        # first token. MHAttentionApproximatorRes (OnlyAtt=True) relies on
        # getting every token back, so this asymmetry is left unchanged.
        if self.n_heads > 1:
            context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.O(context)

        return output[:, 0] if self.OnlyAtt else output

class MHAttentionApproximatorRes(nn.Module):
    """Embed tokens, run self-attention, project back and add a residual.

    Computes ``src + O(MHAttention(embedding(src)))``.

    Args:
        d_model: Embedding / attention width.
        Qbias, Vbias, n_heads: Forwarded to ``MHAttention``.
        i_dim: Feature size of ``src``; the embedding reads it and the output
            projection writes it, so the residual sum lines up.
        o_dim: Accepted but not used.
        mask: Passed to ``MHAttention`` as its ``Mask`` argument.
    """

    def __init__(self, d_model=1, Qbias=False, Vbias=True, i_dim=1, o_dim=1, n_heads=1, mask=None):
        super().__init__()
        # Submodules are created in the original order (attention, embedding,
        # output projection) so seeded parameter initialisation is unchanged.
        self.MHAttention = MHAttention(
            d_model,
            Qbias=Qbias,
            Vbias=Vbias,
            OnlyAtt=True,
            n_heads=n_heads,
            Mask=mask,
        )
        self.embedding = nn.Linear(i_dim, d_model)
        self.O = nn.Linear(d_model, i_dim)

    def forward(self, src, mask=None):
        # NOTE(review): the residual needs MHAttention(OnlyAtt=True) to return
        # every token; that holds only for n_heads == 1 — confirm before
        # using more heads here.
        hidden = self.MHAttention(self.embedding(src), mask=mask)
        return src + self.O(hidden)
    

class MLP(nn.Module):
    """Two-layer perceptron ``Linear -> ReLU -> Linear`` over the last axis."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        stages = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        # Kept as an nn.Sequential named `model` so parameter names
        # (model.0.*, model.2.*) stay stable for saved state dicts.
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        result = self.model(x)
        return result

class TLayerSpec(nn.Module):
    """One layer: a residual attention block followed by a token-wise MLP.

    Args:
        d_model: Attention width inside ``MHAttentionApproximatorRes``.
        Qbias, Vbias, n_heads: Forwarded to the attention block.
        Mask, OnlyAtt: Accepted but not used by this layer.
        dimff: Hidden size of the MLP.
        d_in: Feature size of the input tokens (and of the attention output).
        d_out: Output feature size of the MLP.
        mask: Forwarded to the attention block's constructor.
    """

    def __init__(self, d_model, Qbias=False, Vbias=True, n_heads=1, Mask=False, OnlyAtt=False, dimff=8, d_in=1, d_out=1, mask=None):
        super().__init__()
        self.attention = MHAttentionApproximatorRes(
            d_model, Qbias=Qbias, Vbias=Vbias, n_heads=n_heads, mask=mask, i_dim=d_in
        )
        self.mlp = MLP(d_in, dimff, d_out)

    def forward(self, src, mask=None):
        # NOTE(review): self.attention already returns src + attention(src),
        # so the MLP sees 2 * src + attention(src). Confirm whether this
        # second residual is intended.
        residual = src + self.attention(src, mask=mask)
        return self.mlp(residual)
    
class MyTransformerSpec(nn.Module):
    """Stack of ``TLayerSpec`` layers applied to ``[x, X]``; returns token 0.

    Args:
        Qbias, Vbias, n_heads: Forwarded to every ``TLayerSpec``.
        Mask: Forwarded to every ``TLayerSpec`` (which does not use it).
        OnlyAtt: Accepted for interface compatibility; not used (every layer
            is built with ``OnlyAtt=False``).
        layers: Number of layers; ``dimFeedForward`` needs at least this many
            entries.
        d_model: Attention width inside each layer.
        dimFeedForward: One ``[d_in, dimff, d_out]`` entry per layer.
            Defaults to ``[[1, 16, 1]]``.
    """

    def __init__(self,
                 Qbias=False, Vbias=True, n_heads=1, Mask=False, OnlyAtt=True,
                 layers=1, d_model=1,
                 dimFeedForward=None):
        super().__init__()
        if dimFeedForward is None:
            # Built per call rather than shared as a mutable default argument.
            dimFeedForward = [[1, 16, 1]]
        self.tlayers = nn.ModuleList([
            TLayerSpec(d_in=dimFeedForward[i][0], d_model=d_model,
                       Qbias=Qbias, Vbias=Vbias, n_heads=n_heads,
                       Mask=Mask, OnlyAtt=False, dimff=dimFeedForward[i][1],
                       d_out=dimFeedForward[i][2])
            for i in range(layers)
        ])
        self.layers = layers

    def forward(self, x, X):
        """Run the layer stack on the concatenation of ``x`` and ``X``.

        Args:
            x: (batch, 1, d_in) query token, d_in = dimFeedForward[0][0].
            X: (batch, seq_len, d_in) context tokens.

        Returns:
            (batch, d_out) features of token 0 after the last layer.
        """
        ce = torch.cat([x, X], dim=1).to(torch.float32)  # (B, 1 + seq_len, d_in)
        _, total_len, _ = ce.shape

        # Block attention from the query token (index 0) to indices
        # 1..total_len-1. A single (S, S) mask broadcasts cleanly over the
        # (batch, heads, S, S) attention scores, so no per-batch copy is built.
        mask = torch.zeros(total_len, total_len, dtype=torch.bool, device=ce.device)
        mask[0, 1:] = True
        # NOTE(review): with this mask token 0 attends only to itself in every
        # layer, and the remaining sub-layers act per token, so the returned
        # ce[:, 0] never depends on X. If the query is meant to read the
        # context, the blocked direction may need to be mask[1:, 0] instead —
        # confirm intent.

        for i in range(self.layers):
            ce = self.tlayers[i](ce, mask=mask)

        return ce[:, 0, :]
   

In [9]:
def random_positive_definite_matrix(n):
    """Return a random symmetric positive-definite ``n x n`` NumPy array.

    The result is ``Q @ diag(lam) @ Q.T``. ``Q`` is the orthonormal factor of
    the QR decomposition of a standard-normal matrix. Each ``lam_i`` equals
    ``1 + |z_i|`` with ``z_i ~ N(0, 1)``, so every eigenvalue is at least 1.
    Draws from NumPy's global random state.
    """
    gaussian = np.random.randn(n, n)
    basis = np.linalg.qr(gaussian)[0]
    spectrum = 1 + np.abs(np.random.randn(n))
    return basis @ (np.diag(spectrum) @ basis.T)

# Example usage: a random SPD matrix, used below as V for the datasets.
n = 4  # matrix dimension
matrix = random_positive_definite_matrix(n=n)

In [15]:
class ContextDataset(Dataset):
    """Synthetic in-context regression dataset.

    Each sample holds ``n + 1`` tokens of dimension ``d`` drawn from
    N(0, 2^2): a query token ``x`` followed by ``n`` context tokens ``X``.
    The label is the mean absolute pairwise difference of the context tokens
    projected onto ``v = V @ x``::

        y = (1 / n^2) * sum_{i, j} |v . (X_i - X_j)|

    Args:
        num_samples: Number of samples to generate.
        d: Token dimension.
        n: Number of context tokens per sample.
        V: (d, d) matrix mapping the query token to ``v``. A tensor with fewer
            than two dimensions (such as the default) is reshaped to (d, d),
            which only works when it has ``d * d`` elements (e.g. one element
            with ``d == 1``). ``None`` means ``torch.zeros(1)``.

    ``__getitem__`` returns ``(x, X, y)`` with shapes (d,), (n, d) and (1,).
    ``x`` and ``X`` are float64; ``y`` uses torch's default dtype.
    """

    def __init__(self, num_samples=1000, d=1, n=10, V=None):
        if V is None:
            V = torch.zeros(1)
        self.data = torch.tensor(
            np.random.normal(0, 2, size=[num_samples, n + 1, d]), dtype=torch.float64
        )

        # Work in float64 like the data; the old code crashed on a float32 V.
        Vm = torch.as_tensor(V, dtype=torch.float64)
        if Vm.dim() < 2:
            Vm = Vm.reshape(d, d)

        x = self.data[:, 0]   # (N, d) query tokens
        X = self.data[:, 1:]  # (N, n, d) context tokens
        v = x @ Vm.T          # (N, d): row k is V @ x_k
        proj = (X @ v.unsqueeze(-1)).squeeze(-1)  # (N, n): v_k . X_{k,i}
        # v . (X_i - X_j) == v . X_i - v . X_j, so every pair comes from one
        # broadcast difference instead of an O(n^2) Python loop per sample.
        pair_abs = (proj.unsqueeze(2) - proj.unsqueeze(1)).abs()  # (N, n, n)
        labels = pair_abs.sum(dim=(1, 2)) / (n ** 2)

        # Labels used to be accumulated in a torch.zeros(1) buffer, i.e. the
        # default dtype; keep that dtype (the training loop compares them with
        # float32 model outputs) and the list of shape-(1,) tensors.
        self.Y = list(labels.to(torch.get_default_dtype()).unsqueeze(1))
        self.V = V

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        return sample[0], sample[1:], self.Y[idx]
# Train and test sets are drawn independently (in this order) with the same V.
dataset, dataset2 = (
    ContextDataset(1000, 4, 10, torch.tensor(matrix, dtype=torch.float64))
    for _ in range(2)
)
dataloader, dataloader2 = DataLoader(dataset, 100), DataLoader(dataset2, 100)

In [18]:

# Train with AdamW, recording the mean per-batch train and test MSE each epoch.
from torch import optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

epochs = 10000
model = MyTransformerSpec(dimFeedForward=[[4, 16, 1]], n_heads=1, layers=1, d_model=4).to(device)
criterion = nn.MSELoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001)
TrainErrorAdam = []
TestErrorAdam = []
for epoch in range(epochs):
    # Training phase
    model.train()
    total_loss = 0
    for x, tok, targets in dataloader:
        # x is (B, d); add a token axis so it can be concatenated with tok (B, n, d).
        x, tok, targets = x.unsqueeze(1).to(device), tok.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(x, tok)  # (B, 1): prediction read from the query token
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_train_loss = total_loss / len(dataloader)
    TrainErrorAdam.append(avg_train_loss)
    if epoch % 100 == 0:
        print(TrainErrorAdam[-1])

    # Evaluation phase: no parameter updates, so skip autograd bookkeeping.
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for x, tok, targets in dataloader2:
            x, tok, targets = x.unsqueeze(1).to(device), tok.to(device), targets.to(device)
            outputs = model(x, tok)
            total_loss += criterion(outputs, targets).item()

    # Average over the test loader's own batch count.
    avg_test_loss = total_loss / len(dataloader2)
    TestErrorAdam.append(avg_test_loss)
    if epoch % 100 == 0:
        print(TestErrorAdam[-1])



#with sdpa_kernel([SDPBackend.MATH]):


119.01182022094727
112.55735473632812
7.600576543807984
7.565220832824707
7.454439640045166
7.468077659606934
7.372849321365356
7.404358720779419
7.291619920730591
7.391693735122681
7.189855003356934
7.428511047363282
7.116103315353394
7.479602336883545
7.047230625152588
7.513826179504394
7.003833293914795
7.566576671600342
6.964155483245849
7.597510099411011
6.9509388446807865
7.628451585769653
6.945015621185303
7.637031984329224
6.946006345748901
7.639179992675781
6.941211843490601
7.64763879776001
6.938714218139649
7.644886207580567
6.937544536590576
7.646400308609008
6.941845083236695
7.645077466964722
6.941077184677124
7.644385528564453
6.936041450500488
7.653135824203491
6.940794801712036
7.650232267379761
6.938742351531983
7.645614337921143
6.935816478729248
7.649838829040528
6.939172887802124
7.651364803314209
6.936597871780395
7.654007911682129
6.936481142044068
7.64950122833252
6.93694167137146
7.657779884338379
6.936551141738891
7.650830507278442
6.934852743148804
7.65298633