In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import numpy as np
from typing import List, Tuple, Dict, Optional
import math
import copy

Exploring Neural Architecture Search (Dummy Data)

In [2]:
class SearchableAttention(nn.Module):
  #Searcheable multi-head attention with architecture choices
  def __init__(self, d_model: int, max_heads: int = 8):
    super().__init__()
    self.d_model = d_model
    self.max_heads = max_heads

    #Architecture parameters (Learnable)
    self.alpha_heads = nn.Parameter(torch.randn(max_heads))
    self.alpha_patterns = nn.Parameter(torch.randn(3)) # full, local, sparse

    #Attention Components for Diff head counts
    self.attentions = nn.ModuleList([nn.MultiheadAttention(d_model, num_heads=i+1, batch_first=True) for i in range(max_heads)])

  def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
    #Softmax over architecture choices
    head_weights = F.softmax(self.alpha_heads, dim=0)
    pattern_weights = F.softmax(self.alpha_patterns, dim=0)

    #weighted combination of diff attention heads
    outputs = []
    for i, attn in enumerate(self.attentions):
      out, _ = attn(x, x, x, attn_mask=mask)
      outputs.append(head_weights[i] * out)

    combined = sum(outputs)

    #Apply diff attention patterns (simplified)
    full_attn = combined
    local_attn = combined * 0.8 # Simulate local attention
    sparse_attn = combined * 0.6 # Simulate Sparse attention

    final_output = (pattern_weights[0] * full_attn + pattern_weights[1] * local_attn + pattern_weights[2] * sparse_attn)
    return final_output

In [3]:
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
  #Softmax over architecture choices
  head_weights = F.softmax(self.alpha_heads, dim=0)
  pattern_weights = F.softmax(self.alpha_patterns, dim=0)

  #weighted combination of diff attention heads
  outputs = []
  for i, attn in enumerate(self.attentions):
    out, _ = attn(x, x, x, attn_mask=mask)
    outputs.append(head_weights[i] * out)

  combined = sum(outputs)

  #Apply diff attention patterns (simplified)
  full_attn = combined
  local_attn = combined * 0.8 # Simulate local attention
  sparse_attn = combined * 0.6 # Simulate Sparse attention

  final_output = (pattern_weights[0] * full_attn + pattern_weights[1] * local_attn + pattern_weights[2] * sparse_attn)
  return final_output

In [4]:
class SearchableFFN(nn.Module):
  #Searchable feed-forward network with diff expansion ratios
  def __init__(self, d_model: int):
    super().__init__()
    self.d_model = d_model

    #Architecture Parameters for diff expansion ratios
    self.alpha_expansion = nn.Parameter(torch.randn(4))

    #Diff FFN configurations
    self.ffns = nn.ModuleList([
    self._make_ffn(d_model, int(d_model * ratio))
    for ratio in [2,4,6,8]])

  def _make_ffn(self, d_in: int, d_hidden: int):
    return nn.Sequential(
        nn.Linear(d_in, d_hidden),
        nn.GELU(),
        nn.Linear(d_hidden, d_in)
    )

  def forward(self, x: torch.Tensor):
    weights = F.softmax(self.alpha_expansion, dim=0)
    outputs = []
    for i, ffn in enumerate(self.ffns):
        outputs.append(weights[i] * ffn(x))
    return sum(outputs)

In [5]:
class SearchableTransformerBlock(nn.Module):
  #Searchable transformer block with multiple architectural choices

  def __init__(self, d_model: int, max_heads: int = 8, dropout: float =0.1):
    super().__init__()
    self.d_model = d_model

    #Architecture choice: skip connection patterns
    self.alpha_skip = nn.Parameter(torch.rand(3))

    self.attention = SearchableAttention(d_model, max_heads)
    self.ffn = SearchableFFN(d_model)

    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
    #Softmax over skip connection patterns
    skip_weights = F.softmax(self.alpha_skip, dim=0)

    #Diff normalization Patterns
    #Pre Norm

    prenorm_x = self.norm1(x)
    prenorm_attn = self.attention(prenorm_x, mask)
    prenorm_out = self.dropout(prenorm_attn)
    prenorm_ffn = self.ffn(self.norm2(prenorm_out))
    prenorm_final = prenorm_out + self.dropout(prenorm_ffn)

    #Post Norm
    postnorm_attn= self.attention(x, mask)
    postnorm_out = self.norm1(x + self.dropout(postnorm_attn)) # Corrected typo
    postnorm_ffn = self.ffn(postnorm_out)
    postnorm_final = self.norm2(postnorm_out + self.dropout(postnorm_ffn))

    #No norm (residual only)
    nonorm_attn = self.attention(x, mask)
    nonorm_out = x + self.dropout(nonorm_attn)
    nonorm_ffn = self.ffn(nonorm_out)
    nonorm_final = nonorm_out + self.dropout(nonorm_ffn)

    # Weighted combination
    return (skip_weights[0] * prenorm_final +
            skip_weights[1] * postnorm_final +
            skip_weights[2] * nonorm_final)

In [6]:
class SearchableLLM(nn.Module):

  # Neural Architecture Search enabled Language Model

  def __init__(self, vocab_size: int,
               d_model: int = 512,
               max_layers: int = 12,
               max_heads: int = 8,
               max_seq_len: int = 1024,
               dropout: float = 0.1):
    super().__init__()

    self.vocab_size = vocab_size
    self.d_model = d_model
    self.max_layers = max_layers

    #Architecture Parameters for Layer Count
    self.alpha_layers = nn.Parameter(torch.randn(max_layers)) # Corrected typo

    #Embeddings
    self.token_embedding = nn.Embedding(vocab_size, d_model)
    self.pos_embedding = nn.Embedding(max_seq_len, d_model)

    #Searchable transformer blocks
    self.blocks = nn.ModuleList([
        SearchableTransformerBlock(d_model, max_heads, dropout)
        for _ in range(max_layers)
    ])

    #Output Projection
    self.norm_f = nn.LayerNorm(d_model) # Corrected typo
    self.lm_head = nn.Linear(d_model, vocab_size)

    self.dropout = nn.Dropout(dropout)

  def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
    batch_size, seq_length = input_ids.shape
    device = input_ids.device

    #Positional Encoding (EMBEDDINGS)
    positions = torch.arange(seq_length, device=device).unsqueeze(0)
    x = self.token_embedding(input_ids) + self.pos_embedding(positions)
    x = self.dropout(x)

    #Weighted combination of diff layer depths
    layer_weights = F.softmax(self.alpha_layers, dim=0)

    #Process through all blocks but weight their contributions
    layer_outputs = [x]
    current_x = x

    for i, block in enumerate(self.blocks):
      current_x = block(current_x, attention_mask)
      layer_outputs.append(current_x)

    #Weighted sum of outputs from diff depths
    weighted_output = sum(layer_weights[i] * layer_outputs[i+1]
                          for i in range(len(self.blocks)))

    x = self.norm_f(weighted_output) # Corrected typo
    logits = self.lm_head(x)

    return logits

  def get_architecture_params(self):
    #Get all architecture parameters for optmization
    arch_params = []
    arch_params.append(self.alpha_layers)

    for block in self.blocks:
      arch_params.append(block.alpha_skip)
      arch_params.append(block.attention.alpha_heads)
      arch_params.append(block.attention.alpha_patterns)
      arch_params.append(block.ffn.alpha_expansion)

    return arch_params

  def get_model_params(self):
    # Get all model parameters (excluding architecture params)
    arch_params_ids = {id(p) for p in self.get_architecture_params()}
    return [p for p in self.parameters() if id(p) not in arch_params_ids]

In [15]:
class NASTrainer:
# Neural Architecture search trainer using DARTS methodology
  def __init__(self,
               model: SearchableLLM,
               train_loader,
               val_loader,
               device: str = 'cuda'):
    self.model = model.to(device)
    self.train_loader = train_loader
    self.val_loader = val_loader
    self.device = device


    # Seperate optimizers for model and architecture parameters
    self.model_optimizer = optim.AdamW(model.get_model_params(),
                                       lr=1e-4,
                                       weight_decay=1e-4)
    self.arch_optimizer = optim.Adam(model.get_architecture_params(), lr=3e-4, betas=(0.5, 0.999))

    self.criterion = nn.CrossEntropyLoss()

  def train_step(self, batch):
    #Single training step
    input_ids, labels = batch
    input_ids, labels = input_ids.to(self.device), labels.to(self.device)

    #shift labels for language modeling
    shift_labels = labels[..., 1:].contiguous()
    # Pass sliced input_ids to the model
    shift_logits = self.model(input_ids[:, :-1])[..., :].contiguous()

    loss = self.criterion(
        shift_logits.view(-1, shift_logits.shape[-1]),
        shift_labels.view(-1)
    )

    return loss

  def search_step(self):
    # DARTS search step: optimize architecture parameters
    self.model.train()

    #Get validation batch for architecture optimization
    try:
      val_batch = next(iter(self.val_loader))
    except StopIteration:
      return 0.0

    #Optimize architecture parameters on validation set
    self.arch_optimizer.zero_grad()
    val_loss = self.train_step(val_batch)
    val_loss.backward()
    self.arch_optimizer.step()

    return val_loss.item()

  def model_step(self):
    #Optimize model parameters on training set
    self.model.train()

    try:
      train_batch = next(iter(self.train_loader))
    except StopIteration:
      return 0.0

    self.model_optimizer.zero_grad()
    train_loss = self.train_step(train_batch)
    train_loss.backward()
    self.model_optimizer.step()

    return train_loss.item()

  def train_epoch(self):
    #Train for one epoch alternating btw model and architecture optimization
    total_model_loss = 0.0
    total_arch_loss = 0.0
    n_steps = 0

    for batch_idx, batch in enumerate(self.train_loader): # Corrected unpacking
      # Initialize loss variables for the current batch
      model_loss = 0.0
      arch_loss = 0.0

      #Alternate btw model and architecture optimization
      if batch_idx % 2 == 0:
        model_loss = self.model_step()
        total_model_loss += model_loss
      else:
        arch_loss = self.search_step()
        total_arch_loss += arch_loss

      n_steps += 1

      if batch_idx % 100 == 0:
        print(f"Batch {batch_idx}, Model Loss: {model_loss:.4f}, Arch Loss: {arch_loss:.4f}")

    return total_model_loss / n_steps, total_arch_loss / n_steps

  def get_final_architecture(self):
    #extract the final discovered architecture
    arch = {}

    #layer depth
    layer_probs = F.softmax(self.model.alpha_layers, dim=0)
    arch['optimal_layers'] = torch.argmax(layer_probs).item() + 1
    arch['layer_weights'] = layer_probs.detach().cpu().numpy()

    #Block Level Choices
    block_configs = []
    for i, block in enumerate(self.model.blocks):
      block_config = {}

      #skip connection patterns
      skip_probs = F.softmax(block.alpha_skip, dim=0)
      skip_patterns = ['prenorm', 'postnorm', 'nonorm']
      block_config['skip_pattern'] = skip_patterns[torch.argmax(skip_probs).item()]


      #Attention heads
      head_probs = F.softmax(block.attention.alpha_heads, dim=0)
      block_config['num_heads'] = torch.argmax(head_probs).item() + 1

      #Attention pattern
      pattern_probs = F.softmax(block.attention.alpha_patterns, dim=0)
      patterns = ['full', 'local', 'sparse']
      block_config['attention_pattern'] = patterns[torch.argmax(pattern_probs).item()]

      #FFN expansion
      exp_probs = F.softmax(block.ffn.alpha_expansion, dim=0)
      expansions = [2, 4, 6, 8]
      block_config['ffn_expansion'] = expansions[torch.argmax(exp_probs).item()]

      block_configs.append(block_config)

    arch['block_configs'] = block_configs

    return arch

  #Example usage and utility functions
  def create_dummy_data(vocab_size: int = 1000, seq_len: int = 128, batch_size: int = 8):
    #Create Dummy data for testing
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
    labels = torch.randint(0, vocab_size, (batch_size, seq_len))
    return [(input_ids, labels)]

In [17]:
def main():
  #example usage of NAS for LLMs
  #Configuration
  vocab_size = 1000
  d_model = 264 # Changed d_model to be divisible by 1, 2, 3, and 4
  max_layers = 6
  max_heads = 4
  seq_len = 128

  #create model
  model = SearchableLLM(
    vocab_size=vocab_size,
    d_model=d_model,
    max_layers=max_layers,
    max_heads=max_heads,
    max_seq_len=seq_len,

    )

  #create dummy data
  train_data = NASTrainer.create_dummy_data(vocab_size=vocab_size, seq_len=seq_len, batch_size=8)
  val_data = NASTrainer.create_dummy_data(vocab_size=vocab_size, seq_len=seq_len, batch_size=4)

  #Initialize trainer
  trainer = NASTrainer(model, train_data, val_data, device='cpu')

  #Training loop
  num_epochs = 5 # You can change this to set the number of epochs
  print("Starting Neural Architecture Search...")
  for epoch in range(num_epochs):
    model_loss, arch_loss = trainer.train_epoch()
    print(f"Epoch {epoch}, Model Loss: {model_loss:.4f}, Arch Loss: {arch_loss:.4f}")

  #extract final architecture
  final_arch = trainer.get_final_architecture()
  print("\nDiscovered Architecture:")
  print(f"Optimal layers: {final_arch['optimal_layers']} ")
  for i, block_config in enumerate(final_arch['block_configs'][:3]):
    print(f"Block {i}: {block_config}")

main()

Starting Neural Architecture Search...
Batch 0, Model Loss: 7.0470, Arch Loss: 0.0000
Epoch 0, Model Loss: 7.0470, Arch Loss: 0.0000
Batch 0, Model Loss: 6.9650, Arch Loss: 0.0000
Epoch 1, Model Loss: 6.9650, Arch Loss: 0.0000
Batch 0, Model Loss: 6.8759, Arch Loss: 0.0000
Epoch 2, Model Loss: 6.8759, Arch Loss: 0.0000
Batch 0, Model Loss: 6.8074, Arch Loss: 0.0000
Epoch 3, Model Loss: 6.8074, Arch Loss: 0.0000
Batch 0, Model Loss: 6.7370, Arch Loss: 0.0000
Epoch 4, Model Loss: 6.7370, Arch Loss: 0.0000

Discovered Architecture:
Optimal layers: 4 
Block 0: {'skip_pattern': 'postnorm', 'num_heads': 3, 'attention_pattern': 'sparse', 'ffn_expansion': 8}
Block 1: {'skip_pattern': 'nonorm', 'num_heads': 1, 'attention_pattern': 'full', 'ffn_expansion': 6}
Block 2: {'skip_pattern': 'nonorm', 'num_heads': 2, 'attention_pattern': 'full', 'ffn_expansion': 2}
