In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
import math

In [2]:
@dataclass
class MistralConfig:
    vocab_size = 1000
    d_model = 128
    d_ff = 1024
    layers = 6
    n_head = 4
    kv_head = 2
    max_pos_embed = 512
    sliding_window = 256
    hidden = 'silu'
    eps = 1e-6
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size = 16
    seq_len = 64
    head_dim = d_model // n_head

In [3]:
class RMSNorm(nn.Module):
    def __init__(self,d_model,eps=1e-6):
        super().__init__()
        self.d_model = d_model
        self.weights = nn.Parameter(torch.ones(d_model))
        self.eps = eps 

    def forward(self,x):
        mean = torch.mean(x**2,dim=-1,keepdim=True)
        rms = torch.sqrt(mean+self.eps)
        x = (x / rms) * self.weights
        return x
    

In [4]:
def precompute_freqs_cis(head_dim,max_pos_embed,theta=10000.0):
    freqs = 1.0 / theta ** (torch.arange(0,head_dim,2).float() / head_dim)
    pos = torch.arange(max_pos_embed)
    angles = torch.outer(pos,freqs)
    return torch.polar(torch.ones_like(angles),angles)

def apply_rotary_embed(x,freqs_cis):
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1],-1,2))
    freqs_cis = freqs_cis.reshape(1,x.shape[1],1,-1)
    x_rotated = torch.view_as_real(x_complex * freqs_cis)
    x_rotated = x_rotated.reshape(x.shape)
    return x_rotated.type_as(x)

In [5]:
class InputEmbedding(nn.Module):
    def __init__(self,vocab_size,d_model):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size,d_model)
        
    def forward(self,x):
        return self.embedding(x)

In [6]:
def repeat_kv(x,n_rep):
    batch_size,seq_len,kv_head,head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:,:,:,None,:].expand(batch_size,seq_len,kv_head,n_rep,head_dim).reshape(batch_size,seq_len,kv_head * n_rep, head_dim)
        )
    
class MistralAttention(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.config = config
        self.d_model = config.d_model
        self.n_head = config.n_head
        self.kv_head = config.kv_head
        self.head_dim = config.head_dim
        self.sliding_window = config.sliding_window
        self.n_rep = self.n_head // self.kv_head
        
        self.q_proj = nn.Linear(self.d_model,self.d_model,bias=False)
        self.k_proj = nn.Linear(self.d_model,self.kv_head * self.head_dim , bias=False)
        self.v_proj = nn.Linear(self.d_model,self.kv_head * self.head_dim ,bias=False)
        self.o_proj = nn.Linear(self.d_model,self.d_model,bias=False)

        self.cache_k = None
        self.cache_v = None

    def _init_cache(self, batch_size):
        device = self.q_proj.weight.device
        self.cache_k = torch.zeros(
            (batch_size, self.sliding_window, self.kv_head, self.head_dim),
            device=device
        )
        self.cache_v = torch.zeros(
            (batch_size, self.sliding_window, self.kv_head, self.head_dim),
            device=device
        )
        
    def _roll_cache(self):
        self.cache_k = torch.roll(self.cache_k, shifts=-self.config.seq_len, dims=1)
        self.cache_v = torch.roll(self.cache_v, shifts=-self.config.seq_len, dims=1)
    
    def forward(self, x, freqs_cis):
        batch_size,seq_len,_ = x.shape

        if self.cache_k is None or self.cache_k.size(0) != batch_size:
            self._init_cache(batch_size)
            
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        
        q = q.view(batch_size, seq_len, self.n_head, self.head_dim)
        k = k.view(batch_size, seq_len, self.kv_head, self.head_dim)
        v = v.view(batch_size, seq_len, self.kv_head, self.head_dim)
    
        q = apply_rotary_embed(q, freqs_cis)
        k = apply_rotary_embed(k,freqs_cis)
        
        self._roll_cache()
        
        valid_len = min(seq_len,self.sliding_window)
        self.cache_k[:,-valid_len:,:,:] = k.detach()
        self.cache_v[:,-valid_len:,:,:] = v.detach()

        keys = repeat_kv(self.cache_k, self.n_rep)
        values = repeat_kv(self.cache_v, self.n_rep)
        
        q = q.transpose(1,2)
        keys = keys.transpose(1,2)
        values = values.transpose(1,2)
        
        attn_weights = (q @ keys.transpose(-2,-1)) / math.sqrt(self.config.head_dim)
        attn_weights = F.softmax(attn_weights,dim=-1)
        output = (attn_weights @ values).transpose(1,2).reshape(batch_size,seq_len,-1)
        
        return self.o_proj(output)


In [7]:
class MistralMLP(nn.Module):
    def __init__(self,config):
        super().__init__()
        d_ff = int(2 * config.d_ff / 3)
        self.gate_proj = nn.Linear(config.d_model,d_ff,bias=False)
        self.layer1 = nn.Linear(config.d_model,d_ff,bias=False)
        self.layer2 = nn.Linear(d_ff,config.d_model,bias=False)
        self.act = F.silu

    def forward(self,x):
        return self.layer2(self.act(self.gate_proj(x)) * self.layer1(x))

In [8]:
class MistralBlock(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.attn = MistralAttention(config)
        self.mlp = MistralMLP(config)
        self.norm1 = RMSNorm(config.d_model,config.eps)
        self.norm2 = RMSNorm(config.d_model,config.eps)

    def forward(self,x,freqs_cis):
        x= x + self.attn(self.norm1(x),freqs_cis)
        out = x + self.mlp(self.norm2(x))
        return out


In [9]:
class Transformer(nn.Module):
        def __init__(self,config):
            super().__init__()
            assert config.vocab_size != -1
            self.config = config
            self.embed = InputEmbedding(config.vocab_size,config.d_model)
            self.layers = nn.ModuleList([MistralBlock(config) for _ in range(config.layers)])
            self.norm = RMSNorm(config.d_model,config.eps)
            self.output = nn.Linear(config.d_model,config.vocab_size,bias=False)
            self.freqs_cis = precompute_freqs_cis(config.head_dim,config.max_pos_embed).to(config.device)
        def reset_cache(self):
            for layer in self.layers:
                layer.attn.cache_k = None
                layer.attn.cache_v = None
            
        def forward(self,x):
            batch_size,seq_len = x.shape
            x = self.embed(x)
            freq_cis = self.freqs_cis[:seq_len]
            
            for layer in self.layers:
                x = layer(x,freq_cis)
            x = self.norm(x)
            logits = self.output(x)
            return logits 

In [10]:
if __name__ == "__main__":
    config = MistralConfig()
    model = Transformer(config).to(config.device)
    dummy_input = torch.randint(config.vocab_size, (config.batch_size, config.seq_len), device=config.device)
    logits = model(dummy_input)
    print(f"Logits shape: {logits.shape}")

Logits shape: torch.Size([16, 64, 1000])


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

class CustomTokenizer:
    def __init__(self, text):
        self.chars = sorted(list(set(text)))
        self.stoi = {ch: i for i, ch in enumerate(self.chars)}
        self.itos = {i: ch for ch, i in self.stoi.items()}

    def encode(self, text):
        return torch.tensor([self.stoi[c] for c in text], dtype=torch.long)

    def decode(self, tokens):
        return ''.join([self.itos[token] for token in tokens])

    def vocab_size(self):
        return len(self.chars)

def download_and_preprocess_dataset():
    import requests
    url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
    response = requests.get(url)
    text = response.text

    tokenizer = CustomTokenizer(text)
    encoded = tokenizer.encode(text)
    return encoded, tokenizer

class ShakespeareDataset(Dataset):
    def __init__(self, data, seq_len):
        self.data = data
        self.seq_len = seq_len

    def __len__(self):
        return len(self.data) - self.seq_len

    def __getitem__(self, idx):
        x = self.data[idx:idx + self.seq_len]
        y = self.data[idx + 1:idx + self.seq_len + 1]
        return x, y

def generate_sample_output(model, tokenizer, config, start_text="The", seq_len=100):
    model.eval()
    model.reset_cache()
    
    # Encode and add batch dimension correctly
    encoded = tokenizer.encode(start_text)
    generated = encoded.unsqueeze(0).to(config.device)  # Shape [1, seq_len]
    
    for _ in range(seq_len):
        context = generated[:, -config.sliding_window:]
        with torch.no_grad():
            logits = model(context)
        next_token = torch.argmax(logits[:, -1, :], dim=-1)
        generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
    
    return tokenizer.decode(generated[0].tolist())

def train_model(model, dataloader, optimizer, criterion, config, tokenizer, epochs=10):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}", leave=False)
        
        for i, (x, y) in enumerate(progress_bar):
            x, y = x.to(config.device), y.to(config.device)

            # Clear gradients FIRST
            optimizer.zero_grad()

            # Forward pass
            logits = model(x)
            loss = criterion(logits.view(-1, config.vocab_size), y.view(-1))

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            progress_bar.set_postfix({"Loss": f"{loss.item():.4f}"})

        # Log epoch average loss
        print(f"Epoch [{epoch + 1}/{epochs}], Average Loss: {total_loss / len(dataloader):.4f}")
        
        # Generate a sample output
        sample_output = generate_sample_output(model, tokenizer, config, start_text="The", seq_len=100)
        print(f"Sample Output after Epoch {epoch + 1}:\n{sample_output}\n")

# Main script
def main():
    # Load and preprocess the dataset
    data, tokenizer = download_and_preprocess_dataset()

    # Update the configuration with the vocabulary size
    config = MistralConfig()
    config.vocab_size = tokenizer.vocab_size()

    # Create the dataset and dataloader
    seq_len = config.seq_len
    dataset = ShakespeareDataset(data, seq_len)
    dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True, drop_last=True)

    # Initialize the model, criterion, and optimizer
    model = Transformer(config).to(config.device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=1e-4)

    # Train the model
    train_model(model, dataloader, optimizer, criterion, config, tokenizer, epochs=10)

    # Save the model and tokenizer
    torch.save(model.state_dict(), "mistral_transformer.pth")
    torch.save(tokenizer, "custom_tokenizer.pth")
    print("Model and tokenizer saved!")

if __name__ == "__main__":
    main()


                                                                                

Epoch [1/10], Average Loss: 0.2663
Sample Output after Epoch 1:
The,  ddddlllllll ttte ishllf tttss st tp st t ts t t t t t thanenef theeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee



                                                                                

Epoch [2/10], Average Loss: 0.0444
Sample Output after Epoch 2:
Theeedddequurhhrrdddex?'dddd'sdddrrrddhheelllhiqullhiquququ?? dddddodood qummmollllllllllllllll quququq



                                                                                

Epoch [3/10], Average Loss: 0.0368
Sample Output after Epoch 3:
Thedd HAUME:

CLIZHARLARICES:

Soodd bbbbbbbbbbb:

KEEEN ttttttterros teeerron t ttteerron thereeed the



                                                                                

Epoch [4/10], Average Loss: 0.0344
Sample Output after Epoch 4:
Therrryy   nnneee,


DUCENNII:
Whous the the the the the ttrus tttttttttttttttttttttttttttttttttttttttt



                                                                                

Epoch [5/10], Average Loss: 0.0330
Sample Output after Epoch 5:
The my hillll thee theee the the the the theeee theeee theee theeeeeeeee thee theeeee theee theeee thee



Epoch 10/10:  95%|██████████▍| 66524/69708 [12:17<00:36, 87.28it/s, Loss=0.0324]

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Define the MoE Layer
class MoELayer(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_experts, top_k=2):
        super(MoELayer, self).__init__()
        
        # Define the experts as a list of fully connected layers
        self.experts = nn.ModuleList([nn.Linear(input_dim, hidden_dim) for _ in range(num_experts)])
        
        # Define the gating network (a simple feed-forward neural network)
        self.gate = nn.Linear(input_dim, num_experts)
        
        # Number of experts to select
        self.top_k = top_k

    def forward(self, x):
        # Get the gating weights (probabilities) from the gating network
        gating_weights = self.gate(x)  # Shape: [batch_size, num_experts]
        
        # Apply softmax to get probabilities (normalized weights for each expert)
        gating_probs = F.softmax(gating_weights, dim=-1)  # Shape: [batch_size, num_experts]
        
        # Get the indices of the top-K experts for each input
        top_k_values, top_k_indices = torch.topk(gating_probs, self.top_k, dim=-1)
        
        # Gather the top-K expert outputs
        expert_outputs = [self.experts[i](x) for i in range(len(self.experts))]
        
        # Select only the top-K experts' outputs for each input in the batch
        # Create a tensor to hold the top-K expert outputs
        top_k_expert_outputs = torch.stack([expert_outputs[i] for i in top_k_indices], dim=1)
        
        # Weighted sum of the top-K expert outputs
        weighted_output = torch.sum(top_k_expert_outputs * top_k_values.unsqueeze(-1), dim=1)
        
        return weighted_output

# Define the full MoE model
class MoEModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_experts, num_classes, top_k=2):
        super(MoEModel, self).__init__()
        self.moe_layer = MoELayer(input_dim, hidden_dim, num_experts, top_k)
        self.fc_out = nn.Linear(hidden_dim, num_classes)  # Final output layer

    def forward(self, x):
        # Apply the MoE layer
        x = self.moe_layer(x)
        # Apply the final classification layer
        x = self.fc_out(x)
        return x

# Example of using the MoE Model
if __name__ == "__main__":
    # Hyperparameters
    input_dim = 128   # Input dimension (e.g., feature size)
    hidden_dim = 64   # Hidden dimension
    num_experts = 8   # Number of experts
    num_classes = 10  # Number of output classes (for classification)
    top_k = 3         # Top-K experts to select

    # Create a random input tensor (batch size of 32)
    batch_size = 32
    x = torch.randn(batch_size, input_dim)

    # Instantiate the MoE model
    model = MoEModel(input_dim, hidden_dim, num_experts, num_classes, top_k)

    # Forward pass
    output = model(x)

    # Print the output
    print(output.shape)  # Expected output: [batch_size, num_classes]


TypeError: only integer tensors of a single element can be converted to an index