In [1]:
import torch
import torch.nn.functional as F

class SRAMSimulator:
    def __init__(self, sram_sizes):
        self.sram = {'input': None, 'weight': None, 'output': None,
                     'input_2': None, 'weight_2': None, 'output_2': None,
                     'input_3': None, 'weight_3': None, 'output_3': None}
        self.sram_sizes = sram_sizes  # Now it's a dictionary that holds sizes for each SRAM segment
        self.transfer_counts = {'load_total': 0, 'clear_total': 0, 'input_load': 0, 'input_clear': 0,
                                'weight_load': 0, 'weight_clear': 0, 'output_load': 0, 'output_clear': 0}

    def load_to_sram(self, name, data ,test_num):
        # Calculate the size of the new data to be added
        data_size = data.numel() * data.element_size()
        temp_out = None
        # Retrieve existing data if it exists, or initialize an empty placeholder
        existing_data = self.sram.get(name, None)
        if existing_data is not None:
            # Calculate the new total size after appending
            total_data_size = existing_data.numel() * existing_data.element_size() + data_size
        else:
            total_data_size = data_size
        
        # Check if the SRAM can accommodate the new total size of data
        if total_data_size > self.sram_sizes[name]:
            temp_out = self.test_output(name)
            self.clear_sram(name)
            existing_data = None
            

        if existing_data is not None:
            if name == 'input':
                self.sram[name] = torch.cat((existing_data, data), dim=0)
            elif name == 'weight':
                self.sram[name] = torch.cat((existing_data, data), dim=1)
            elif name == 'output':
               self.sram[name] = torch.cat((existing_data, data), dim=0)
        else:
            self.sram[name] = data

        # Update the transfer counts
        self.transfer_counts['load_total'] += 1
        self.transfer_counts[f'{name}_load'] += 1
      #  print(temp_out)
        return temp_out
    def test_output(self,name):
        out = self.sram[name]
        return out
    def clear_sram(self, name):
       # print(name+"clear")
        self.sram[name] = None
        self.transfer_counts['clear_total'] += 1
        self.transfer_counts[f'{name}_clear'] += 1
      #  print(f"Cleared {name} from SRAM. Transfer counts: {self.transfer_counts}")

    def calculate_attention(self, embedding, wq, wk, wv, test_num):
      
        self.load_to_sram('input',embedding,test_num)

        # Calculate Q
        self.load_to_sram('weight', wq,test_num)
        q = torch.matmul(self.sram['input'], self.sram['weight'])
        self.clear_sram('weight')
        self.load_to_sram('output',q,test_num)
        self.clear_sram('output')

        
        # Calculate K
        self.load_to_sram('weight', wk,test_num) 
        k = torch.matmul(self.sram['input'], self.sram['weight'])
        self.clear_sram('weight')
        self.load_to_sram('output',k,test_num)
        self.clear_sram('output')

        # Calculate V
        self.load_to_sram('weight', wv,test_num)
        v = torch.matmul(self.sram['input'], self.sram['weight'])
        self.clear_sram('weight')
        self.load_to_sram('output',v,test_num)
        self.clear_sram('output')


        self.clear_sram('input')
        # Calculate Attention Scores
        k_t = k.transpose(-2, -1)
        self.load_to_sram('weight', k_t,test_num)
        self.load_to_sram('input', q,test_num)
        attn_scores = torch.matmul(self.sram['input'], self.sram['weight'])
        self.clear_sram('input')
        self.clear_sram('weight')
        self.load_to_sram('output',attn_scores,test_num)
        self.clear_sram('output')


        attn_weights = F.softmax(attn_scores / (k.size(-1) ** 0.5), dim=-1)

        # Calculate Attention Output
        self.load_to_sram('input', attn_weights,test_num)
        self.load_to_sram('weight', v,test_num)
        attn_output = torch.matmul(self.sram['input'], self.sram['weight'])
        self.clear_sram('input')
        self.clear_sram('weight')
        self.load_to_sram('output', attn_output,test_num)
        self.clear_sram('output')

        output_size = attn_output.numel() * attn_output.element_size()
        print(f"Final output size: {output_size} bytes")

        
        return attn_output

    def get_transfer_count(self, name='total'):
        return self.transfer_counts[name]

    def cal_att2(self, embedding, wq, wk, wv, test_num, dimension, max_seq):
        #print("====================== Word Embedding ======================")
        Q = None
        K = None
        V = None
        att_u = None
        final_out = None
        for i in range(0,max_seq,test_num):
            clear_input = self.load_to_sram('input',embedding[0][i:i+test_num],test_num)
            counter = 0
            for j in range(0,dimension,test_num):
                clear_weight = self.load_to_sram('weight',wq[:,j:j+test_num],test_num)
                counter2 = 0
                if i == 0 and j == 0:
                    # Assuming that existing_data and data are both PyTorch tensors
                    # We need to concatenate along a specific dimension (0 by default)
                    q2 = torch.matmul(self.sram['input'],self.sram['weight'])

                elif clear_input != None and clear_weight != None:
                    q2 = torch.matmul(self.sram['input'],self.sram['weight'])
                    counter = 0

                elif clear_input != None and clear_weight == None:
                    q2 = torch.matmul(self.sram['input'],self.sram['weight'][:,counter2:counter2+test_num])
                    counter = 0
                 
                elif clear_input == None and clear_weight != None:
                    q2 = torch.matmul(self.sram['input'][counter:counter+test_num],self.sram['weight'])
                  
                else:
                    q2 = torch.matmul(self.sram['input'][counter:counter+test_num],self.sram['weight'][:,counter2:counter2+test_num])
                   
                a = self.load_to_sram('output',q2,test_num)
                counter2 += 1
                if a!= None and Q == None:
                    Q = a
                elif a != None and Q != None :
                    Q = torch.cat((Q,a),dim = 0)
            counter += 1
        #Q = self.sram['output']
        if Q == None:
            Q = self.sram['output']
        elif Q.shape[0] != dimension*max_seq/2:
            Q = torch.cat((Q,self.sram['output']),dim = 0)
        Q = Q.view(max_seq,dimension)
        #print(Q.shape)
        self.clear_sram('input')
        self.clear_sram('weight')
        self.clear_sram('output')

        for i in range(0,max_seq,test_num):
            input = embedding[0][i:i+test_num]
            counter = 0
            for j in range(0,dimension,test_num):
                clear_weight = self.load_to_sram('weight',wv[:,j:j+test_num],test_num)
                counter2 = 0
                if i == 0 and j == 0:
                    # Assuming that existing_data and data are both PyTorch tensors
                    # We need to concatenate along a specific dimension (0 by default)
                    v2 = torch.matmul(input,self.sram['weight'])

                elif clear_input != None and clear_weight != None:
                    v2 = torch.matmul(input,self.sram['weight'])
                    counter = 0

                elif clear_input != None and clear_weight == None:
                    v2 = torch.matmul(input,self.sram['weight'][:,counter2:counter2+test_num])
                    counter = 0
                 
                elif clear_input == None and clear_weight != None:
                    v2 = torch.matmul(input,self.sram['weight'])
                  
                else:
                    v2 = torch.matmul(input,self.sram['weight'][:,counter2:counter2+test_num])

                a = self.load_to_sram('output',v2,test_num)
                counter2 += 1
                if a!= None and V == None:
                    V = a
                elif a != None and V != None :
                    V = torch.cat((V,a),dim = 0)
        #Q = self.sram['output']
        if V == None:
            V = self.sram['output']
        elif V.shape[0] != dimension*max_seq/2:
            V = torch.cat((V,self.sram['output']),dim = 0)
        V = V.view(max_seq,dimension)
        self.clear_sram('input')
        self.clear_sram('weight')
        self.clear_sram('output')

        for i in range(0,max_seq,test_num):
            input = embedding[0][i:i+test_num]
            counter = 0
            for j in range(0,dimension,test_num):
                clear_weight = self.load_to_sram('weight',wk[:,j:j+test_num],test_num)
                counter2 = 0
                if i == 0 and j == 0:
                    # Assuming that existing_data and data are both PyTorch tensors
                    # We need to concatenate along a specific dimension (0 by default)
                    k2 = torch.matmul(input,self.sram['weight'])

                elif clear_input != None and clear_weight != None:
                    k2 = torch.matmul(input,self.sram['weight'])
                    counter = 0

                elif clear_input != None and clear_weight == None:
                    k2 = torch.matmul(input,self.sram['weight'][:,counter2:counter2+test_num])
                    counter = 0
                 
                elif clear_input == None and clear_weight != None:
                    k2 = torch.matmul(input,self.sram['weight'])
                  
                else:
                    k2 = torch.matmul(input,self.sram['weight'][:,counter2:counter2+test_num])
                a = self.load_to_sram('output',k2,test_num)
                counter2 += 1
                if a!= None and K == None:
                    K = a
                elif a != None and K != None :
                    K = torch.cat((K,a),dim = 0)
        if K == None:
            K = self.sram['output']
        elif K.shape[0] != dimension*max_seq/2:
            K = torch.cat((K,self.sram['output']),dim = 0)

        K = K.view(max_seq,dimension)
        self.clear_sram('input')
        self.clear_sram('weight')
        self.clear_sram('output')
        k_t = K.transpose(-2,-1)
       # print(k_t.shape)

        #print("====================== qkt ======================")
        #print(Q.shape)
        #print(k_t.shape)
        for i in range(0, max_seq, test_num):
            clear_input = self.load_to_sram('input', Q[i:i+test_num], test_num)
            counter = 0
            for j in range(0, max_seq, test_num):
                clear_weight = self.load_to_sram('weight', k_t[:,j:j+test_num], test_num)
                counter2 = 0
                if i == 0 and j == 0:
                    # Assuming that existing_data and data are both PyTorch tensors
                    # We need to concatenate along a specific dimension (0 by default)
                    att = torch.matmul(self.sram['input'],self.sram['weight'])

                elif clear_input != None and clear_weight != None:
                    att = torch.matmul(self.sram['input'],self.sram['weight'])
                    counter = 0

                elif clear_input != None and clear_weight == None:
                    att = torch.matmul(self.sram['input'],self.sram['weight'][:,counter2 :counter2+test_num])
                    counter = 0
                 
                elif clear_input == None and clear_weight != None:
                    att = torch.matmul(self.sram['input'][counter:counter+test_num],self.sram['weight'])
                  
                else:
                    att = torch.matmul(self.sram['input'][counter:counter+test_num],self.sram['weight'][:,counter2:counter2+test_num])

                a = self.load_to_sram('output',att,test_num)
                counter2 += 1
                if a!= None and att_u == None:
                    att_u = a
                elif a != None and att != None :
                    att_u = torch.cat((att_u,a),dim = 0)
            self.clear_sram('weight')
        if att_u == None:
            att_u = self.sram['output']
        elif att_u.shape[0] != dimension*max_seq/2:
            att_u = torch.cat((att_u,self.sram['output']),dim = 0)
        att_u = att_u.view(max_seq,max_seq)
        #print(att_u.shape)
        attn_weights = F.softmax(att_u/ (K.size(-1) ** 0.5), dim=-1)
       # print(attn_weights.shape)
        self.clear_sram('weight')
        self.clear_sram('input')
        self.clear_sram('output')

       # print("====================== attention ======================")
        #print(attn_weights.shape)
       # print(V.shape)
        for i in range(0, max_seq, test_num):
            clear_input = self.load_to_sram('input', attn_weights[i:i+test_num], test_num)
            counter = 0
            for j in range(0, dimension, test_num):
                counter2 = 0
                clear_weight = self.load_to_sram('weight', V[:,j:j+test_num], test_num)
                if i == 0 and j == 0:
                    # Assuming that existing_data and data are both PyTorch tensors
                    # We need to concatenate along a specific dimension (0 by default)
                    out = torch.matmul(self.sram['input'],self.sram['weight'])

                elif clear_input != None and clear_weight != None:
                    out = torch.matmul(self.sram['input'],self.sram['weight'])
                    counter = 0

                elif clear_input != None and clear_weight == None:
                    out = torch.matmul(self.sram['input'],self.sram['weight'][:,counter2:counter2+test_num])
                    counter = 0
                 
                elif clear_input == None and clear_weight != None:
                    out = torch.matmul(self.sram['input'][counter:counter+test_num],self.sram['weight'])
                  
                else:
                    out = torch.matmul(self.sram['input'][counter:counter+test_num],self.sram['weight'][:,counter2:counter2+test_num])
                
                a = self.load_to_sram('output',out,test_num)
                counter2 += 1
                if a!= None and final_out == None:
                    final_out = a
                elif a != None and final_out != None :
                    final_out = torch.cat((final_out,a),dim = 0)
            self.clear_sram('weight')
            
        if final_out == None:
            final_out = self.sram['output']
        elif final_out.shape[0] != dimension*max_seq/2:
            final_out = torch.cat((final_out,self.sram['output']),dim = 0)
        final_out = final_out.view(max_seq,dimension)
        #print(final_out.shape)
        self.clear_sram('output')
        print("Finish") 

    def cal_att3(self, embedding, wq, wk, wv, test_num, dimension, max_seq):
        print("====================== Word Embedding ======================")
        Q = None
        K = None
        V = None
        Att = None
        final_out = None
        for i in range(0,max_seq,test_num):
            clear_input = self.load_to_sram('input',embedding[0][i:i+test_num],test_num)
            counter = 0
            counter2 = 0
            for j in range(0,dimension,test_num):
                clear_weight_Q = self.load_to_sram('weight',wq[:,j:j+test_num],test_num)
                clear_weight_K = self.load_to_sram('weight',wk[:,j:j+test_num],test_num)
                clear_weight_V = self.load_to_sram('weight',wv[:,j:j+test_num],test_num)

                if clear_weight_K != None:
                    #print('b')
                    temp_q = torch.matmul(self.sram['input'][i:i+test_num],clear_weight_K[:,counter2:counter2+test_num])
                    counter2 = 0
                    temp_k = torch.matmul(self.sram['input'][i:i+test_num],self.sram['weight'][:,counter2:counter2+test_num])
                    temp_v = torch.matmul(self.sram['input'][i:i+test_num],self.sram['weight'][:,counter2+2:counter2+test_num+2])
                    counter2 += 4
                elif clear_weight_V != None:
                    #print('c')
                    temp_k = torch.matmul(self.sram['input'][i:i+test_num],clear_weight_V[:,counter2:counter2+test_num])
                    temp_q = torch.matmul(self.sram['input'][i:i+test_num],clear_weight_V[:,counter2-2:counter2+test_num-2])
                    counter2 = 0
                    temp_v = torch.matmul(self.sram['input'][i:i+test_num],self.sram['weight'][:,counter2:counter2+test_num])
                    counter2 += 2
                else:
                    #print('d')
                    temp_q = torch.matmul(self.sram['input'][i:i+test_num],self.sram['weight'][:,counter2:counter2+test_num])
                    temp_k = torch.matmul(self.sram['input'][i:i+test_num],self.sram['weight'][:,counter2+2:counter2+test_num+2])
                    temp_v = torch.matmul(self.sram['input'][i:i+test_num],self.sram['weight'][:,counter2+4:counter2+test_num+4])
                    counter2 += 6          
                if i == 0 and j == 0:
                    Q = temp_q
                    K = temp_k
                    V = temp_v
                else:
                    Q = torch.cat((Q,temp_q),dim = 0)
                    K = torch.cat((K,temp_k),dim = 0)
                    V = torch.cat((V,temp_v),dim = 0)
                
            print(Q.shape)
            print(K.shape)
            print(V.shape)
            Q.view(max_seq,dimension)
            V.view(max_seq,dimension)
            K.view(max_seq,dimension)
            
            k_t = K.transpose(-2,-1)
            
            print("====================== qkt ======================")
            for i in range(0, max_seq, test_num):
                clear_input = self.load_to_sram('input', Q[i:i+test_num], test_num)
                counter2 =  0
                for j in range(0, max_seq, test_num):
                    clear_weight = self.load_to_sram('weight', k_t[:,j:j+test_num], test_num)
                    counter2 = 0
                    if i == 0 and j == 0:
                        # Assuming that existing_data and data are both PyTorch tensors
                        # We need to concatenate along a specific dimension (0 by default)
                        att = torch.matmul(self.sram['input'],self.sram['weight'])

                    elif clear_input != None and clear_weight != None:
                        att = torch.matmul(self.sram['input'],self.sram['weight'])
                        counter = 0

                    elif clear_input != None and clear_weight == None:
                        att = torch.matmul(self.sram['input'],self.sram['weight'][:,counter2 :counter2+test_num])
                        counter = 0
                    
                    elif clear_input == None and clear_weight != None:
                        att = torch.matmul(self.sram['input'][counter:counter+test_num],self.sram['weight'])
                    
                    else:
                        att = torch.matmul(self.sram['input'][counter:counter+test_num],self.sram['weight'][:,counter2:counter2+test_num])

                    a = self.load_to_sram('output',att,test_num)
                    counter2 += 1
                    if a!= None and att_u == None:
                        att_u = a
                    elif a != None and att != None :
                        att_u = torch.cat((att_u,a),dim = 0)
                self.clear_sram('weight')
            if att_u == None:
                att_u = self.sram['output']
            elif att_u.shape[0] != dimension*max_seq/2:
                att_u = torch.cat((att_u,self.sram['output']),dim = 0)
            att_u = att_u.view(max_seq,max_seq)
            print(att_u.shape)
            attn_weights = F.softmax(att_u/ (K.size(-1) ** 0.5), dim=-1)
            print(attn_weights.shape)
            self.clear_sram('weight')
            self.clear_sram('input')
            self.clear_sram('output')
            

In [2]:
# Example usage:
dimension = 512
max_seq = 100
test_num = 2
sram_sizes = {  
    'input':  10000* max_seq * 4,  # SRAM size for input
    'weight': dimension * dimension * 4,  # SRAM size for weights
    'output': dimension * dimension * 4   # SRAM size for output
}
simulator = SRAMSimulator(sram_sizes)

embedding = torch.randn( 1, max_seq, dimension)  # Example input embedding
wq = torch.randn(dimension, dimension)  # Example weight for Q
wk = torch.randn(dimension, dimension)  # Example weight for K
wv = torch.randn(dimension, dimension)  # Example weight for V
attn_output = simulator.cal_att3(embedding, wq, wk, wv, test_num, dimension, max_seq)

        

torch.Size([512, 2])
torch.Size([512, 2])
torch.Size([512, 2])


RuntimeError: shape '[100, 512]' is invalid for input of size 1024

In [None]:
dimension = 512
max_seq = 100
test_num = 2
import csv

head = ['buffer_size', 'load_total', 'clear_total', 'input_load', 'input_clear', 'weight_load', 'weight_clear', 'output_load', 'output_clear']

for sram_buffer in range(1024,262144,10240):
    sram_sizes = {  
        'input':   sram_buffer * 4,  # SRAM size for input
        'weight':  sram_buffer * 4,  # SRAM size for weights
        'output':  sram_buffer * 4   # SRAM size for output
    }
    simulator = SRAMSimulator(sram_sizes)

    embedding = torch.randn( 1, max_seq, dimension)  # Example input embedding
    wq = torch.randn(dimension, dimension)  # Example weight for Q
    wk = torch.randn(dimension, dimension)  # Example weight for K
    wv = torch.randn(dimension, dimension)  # Example weight for V
    attn_output = simulator.cal_att2(embedding, wq, wk, wv, test_num, dimension, max_seq)
    if sram_buffer == 1024:
        data = [head,[sram_buffer,simulator.transfer_counts['load_total'],simulator.transfer_counts['clear_total'],simulator.transfer_counts['input_load'],
                simulator.transfer_counts['input_clear'],simulator.transfer_counts['weight_load'],simulator.transfer_counts['weight_clear'],
                simulator.transfer_counts['output_load'],simulator.transfer_counts['output_clear']]]   
    else:
        data = [[sram_buffer,simulator.transfer_counts['load_total'],simulator.transfer_counts['clear_total'],simulator.transfer_counts['input_load'],
                simulator.transfer_counts['input_clear'],simulator.transfer_counts['weight_load'],simulator.transfer_counts['weight_clear'],
                simulator.transfer_counts['output_load'],simulator.transfer_counts['output_clear']]]
    with open('outputt.csv', mode= 'a+' , newline='') as file:
        writer = csv.writer(file)
        writer.writerows(data)

            
    

{'load_total': 262, 'clear_total': 1, 'input_load': 1, 'input_clear': 0, 'weight_load': 261, 'weight_clear': 1, 'output_load': 0, 'output_clear': 0}
Finish
Finish
Finish
Finish
Finish
Finish
Finish
Finish
Finish
Finish
Finish
Finish
Finish
Finish
Finish
Finish
Finish
Finish
Finish
Finish
Finish
Finish
Finish
Finish
Finish
Finish


In [None]:
print(simulator.transfer_counts['load_total'])

In [None]:
# Define the base SRAM sizes
base_sram_sizes = {  
    'input': 512 * 100 * 4,
    'weight': 512 * 512 * 4,
    'output': 512 * 512 * 4
}



# Define the factors by which to multiply each SRAM size to create different configurations
factors = [0.5, 1, 2, 4]

# Generate combinations of SRAM sizes
sram_combinations = []
for input_factor in factors:
    for weight_factor in factors:
        for output_factor in factors:
            sram_combinations.append({
                'input': base_sram_sizes['input'] * input_factor,
                'weight': base_sram_sizes['weight'] * weight_factor,
                'output': base_sram_sizes['output'] * output_factor
            })

# Now sram_combinations contains all the different SRAM size configurations


In [None]:
# Example usage:
dimension = 512
max_seq = 100
test_num = 2
sram_sizes = {  
    'input':  dimension* max_seq * 4,  # SRAM size for input
    'weight': dimension * dimension * 4,  # SRAM size for weights
    'output': dimension * dimension * 4   # SRAM size for output
}
simulator = SRAMSimulator(sram_sizes)

embedding = torch.randn( 1, max_seq, dimension)  # Example input embedding
wq = torch.randn(dimension, dimension)  # Example weight for Q
wk = torch.randn(dimension, dimension)  # Example weight for K
wv = torch.randn(dimension, dimension)  # Example weight for V
attn_output = simulator.cal_att2(embedding, wq, wk, wv, test_num, dimension, max_seq)