In [1]:
import pandas as pd
df = pd.read_csv("C:\\MachineLearning\\UsamaKenway\\DeepSeek-R1-OpenHermes-2.5\\datasets\\CoT_dataset\\dataset_0010.csv")

In [2]:
df.shape

(10000, 22)

In [3]:
df = df.dropna(subset=['instruct_prompt'])

In [4]:
df = df[:5000]

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
import torch.utils.checkpoint as checkpoint

# Load LLaMA tokenizer and MiniLM embeddings
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
embedding_model =  SentenceTransformer("NovaSearch/stella_en_400M_v5", trust_remote_code=True, device='cuda').cuda() #= SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
print(f"Embedding model's dimensions: {len(embedding_model.encode('ok'))}")

def get_embeddings(texts):
    return torch.tensor(embedding_model.encode(texts), dtype=torch.float32)

# Positional Encoding
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.pe = pe.unsqueeze(0)
    
    def forward(self, x):
        return x + self.pe[:, :x.size(1)].to(x.device)

# Multi-Head Self-Attention
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.qkv_proj = nn.Linear(d_model, d_model * 3)
        self.out_proj = nn.Linear(d_model, d_model)
    
    def forward(self, x):
        B, T, C = x.shape
        qkv = self.qkv_proj(x).reshape(B, T, 3, self.num_heads, self.d_k).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn_weights = (q @ k.transpose(-2, -1)) / (self.d_k ** 0.5)
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_output = (attn_weights @ v).transpose(1, 2).reshape(B, T, C)
        return self.out_proj(attn_output)

# Feed-Forward Network
class FeedForward(nn.Module):
    def __init__(self, d_model, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(d_model, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, d_model)
    
    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))

# Transformer Block
class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, hidden_dim):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, hidden_dim)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
    
    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ff(self.norm2(x))
        return x
        
class TextDataset(Dataset):
    def __init__(self, df, max_len=100):
        self.texts = df['instruct_prompt'].tolist()
        # Reshape embeddings to match expected dimensions
        self.embeddings = [
            torch.nn.functional.pad(
                get_embeddings(text).unsqueeze(0), 
                (0, 0, 0, max_len - 1)
            ) if len(get_embeddings(text).unsqueeze(0)) < max_len 
            else get_embeddings(text).unsqueeze(0)[:max_len]
            for text in self.texts
        ]
        self.tokenized = [
            tokenizer(
                text, 
                return_tensors='pt', 
                padding="max_length", 
                truncation=True, 
                max_length=max_len
            )["input_ids"].squeeze(0) 
            for text in self.texts
        ]
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        return self.embeddings[idx], self.tokenized[idx]

class Transformer(nn.Module):
    def __init__(self, d_model, num_heads, hidden_dim, num_layers, vocab_size, max_len=5000):
        super().__init__()
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList([TransformerBlock(d_model, num_heads, hidden_dim) for _ in range(num_layers)])
        self.output_layer = nn.Linear(d_model, vocab_size)
        # Add input projection layer to handle single token embeddings
        self.input_projection = nn.Linear(d_model, d_model)
    
    def forward(self, x):
        # Ensure input has correct shape [batch_size, seq_len, d_model]
        if len(x.shape) == 2:
            x = x.unsqueeze(1)
        
        # Project input
        x = self.input_projection(x)
        x = self.pos_encoding(x)
        for layer in self.layers:
            # x = layer(x)
            x = checkpoint.checkpoint(layer, x)  # Apply gradient checkpointing per layer
        return self.output_layer(x)

    def decode(self, logits):
        token_ids = torch.argmax(logits, dim=-1)
        return [tokenizer.decode(ids, skip_special_tokens=True) for ids in token_ids]

# Training setup
def train_model(model, train_loader, num_epochs, device, vocab_size):
    optimizer = optim.AdamW(model.parameters(), lr=5e-5)
    criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
    
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device).float(), targets.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            
            # Reshape outputs and targets for loss calculation
            outputs = outputs.view(-1, vocab_size)
            targets = targets.view(-1)
            
            loss = criterion(outputs, targets)
            loss.backward()
            
            # Add gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            optimizer.step()
            total_loss += loss.item()
            
            if batch_idx % 10 == 0:
                print(f"Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}")
        
        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}, Average Loss: {avg_loss:.4f}")



A matching Triton is not available, some optimizations will not be enabled
Traceback (most recent call last):
  File "C:\Python\py311env\Lib\site-packages\xformers\__init__.py", line 57, in _is_triton_available
    import triton  # noqa
    ^^^^^^^^^^^^^
  File "C:\Python\py311env\Lib\site-packages\triton\__init__.py", line 20, in <module>
    from .runtime import (
  File "C:\Python\py311env\Lib\site-packages\triton\runtime\__init__.py", line 1, in <module>
    from .autotuner import (Autotuner, Config, Heuristics, autotune, heuristics)
  File "C:\Python\py311env\Lib\site-packages\triton\runtime\autotuner.py", line 9, in <module>
    from .jit import KernelInterface
  File "C:\Python\py311env\Lib\site-packages\triton\runtime\jit.py", line 12, in <module>
    from ..runtime.driver import driver
  File "C:\Python\py311env\Lib\site-packages\triton\runtime\driver.py", line 1, in <module>
    from ..backends import backends
  File "C:\Python\py311env\Lib\site-packages\triton\backends\__ini

Embedding model's dimensions: 1024


In [6]:
%%time
embedding_model.encode("ok")

CPU times: total: 15.6 ms
Wall time: 37 ms


array([ 0.20521261, -0.08252388, -2.8384926 , ..., -1.2719604 ,
        0.5373522 , -0.71815664], dtype=float32)

In [35]:
def train_model(model, train_loader, start_epoch, num_epochs, device, vocab_size):
    """
    Memory-optimized training function using Adafactor and gradient accumulation
    """
    # Import Adafactor from transformers
    from transformers.optimization import Adafactor
    
    # Initialize Adafactor optimizer with memory-efficient settings
    optimizer = Adafactor(
        model.parameters(),
        eps=(1e-30, 1e-3),
        clip_threshold=1.0,
        decay_rate=-0.8,
        beta1=None,
        weight_decay=0.01,
        scale_parameter=True,
        relative_step=True,  # Use relative step sizing
        warmup_init=True     # Enable warmup
    )
    
    criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
    
    # Gradient accumulation steps
    gradient_accumulation_steps = 4
    
    # Enable gradient checkpointing for memory efficiency
    # model.gradient_checkpointing_enable() # im not using it because of error
    
    # Initialize mixed precision training
    scaler = torch.cuda.amp.GradScaler()
    
    for epoch in range(start_epoch, num_epochs):
        model.train()
        total_loss = 0
        optimizer.zero_grad()
        
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            # Move to device and handle data types
            inputs = inputs.to(device).float()
            targets = targets.to(device)
            
            # Use automatic mixed precision
            with torch.cuda.amp.autocast():
                outputs = model(inputs)
                outputs = outputs.view(-1, vocab_size)
                targets = targets.view(-1)
                
                # Scale loss by gradient accumulation steps
                loss = criterion(outputs, targets) / gradient_accumulation_steps
            
            # Scale gradients and backward pass
            scaler.scale(loss).backward()
            
            # Update weights if we've accumulated enough gradients
            if (batch_idx + 1) % gradient_accumulation_steps == 0:
                # Unscale gradients for clipping
                scaler.unscale_(optimizer)
                
                # Clip gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                
                # Optimizer step with scaler
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
            
            # Track loss (multiply by accumulation steps to get true loss)
            total_loss += loss.item() * gradient_accumulation_steps
            
            # Print progress
            if batch_idx % 50 == 0:
                avg_loss = total_loss / (batch_idx + 1)
                print(f"Epoch {epoch+1}, Batch {batch_idx}, "
                      f"Loss: {loss.item() * gradient_accumulation_steps:.4f}, "
                      f"Avg Loss: {avg_loss:.4f}")
                
                # # Print memory usage if on cuda
                # if device.type == 'cuda':
                #     print(f"GPU Memory: "
                #           f"{torch.cuda.memory_allocated(device) / 1024**2:.1f}MB / "
                #           f"{torch.cuda.memory_reserved(device) / 1024**2:.1f}MB")
        
        # Compute average loss for the epoch
        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}, Average Loss: {avg_loss:.4f}")
        
        # Optional: Save checkpoint after each epoch
        torch.save(model, 'model.pth')
        # torch.save({
        #     'epoch': epoch,
        #     'model_state_dict': model.state_dict(),
        #     'optimizer_state_dict': optimizer.state_dict(),
        #     'loss': avg_loss,
        # }, f'checkpoint.pt') #f'checkpoint_epoch_{epoch+1}.pt'

# Memory optimization settings before training
def optimize_memory():
    """Apply memory optimizations before training"""
    import gc
    import torch
    
    # Clear memory
    gc.collect()
    torch.cuda.empty_cache()
    
    # Set memory allocator settings
    if torch.cuda.is_available():
        torch.cuda.set_per_process_memory_fraction(0.9)  # Use 90% of available memory
        torch.backends.cudnn.benchmark = True

In [21]:
len(embedding_model.encode("g"))

1024

In [9]:
from tqdm import tqdm

class TextDataset(Dataset):
    def __init__(self, df, max_len=100):
        self.texts = df['instruct_prompt'].tolist()
        
        # Track progress with tqdm
        self.embeddings = []
        self.tokenized = []
        
        for text in tqdm(self.texts, desc="Processing Texts", leave=False):
            try:
                # Handle embeddings
                embedding = get_embeddings(text)
                if isinstance(embedding, torch.Tensor):  # Ensure it's a tensor
                    embedding = torch.nn.functional.pad(
                        embedding.unsqueeze(0), 
                        (0, 0, 0, max_len - 1)
                    ) if len(embedding.unsqueeze(0)) < max_len else embedding.unsqueeze(0)[:max_len]
                    self.embeddings.append(embedding)
                else:
                    raise ValueError(f"Embedding not tensor for text: {text}")
                
                # Handle tokenization
                tokenized = tokenizer(
                    text, 
                    return_tensors='pt', 
                    padding="max_length", 
                    truncation=True, 
                    max_length=max_len
                )["input_ids"].squeeze(0)
                self.tokenized.append(tokenized)
                
            except Exception as e:
                print(f"Skipping row due to error: {e}")
                continue
    
    def __len__(self):
        return len(self.embeddings)
    
    def __getitem__(self, idx):
        return self.embeddings[idx], self.tokenized[idx]


In [32]:
import torch

def load_checkpoint(model, optimizer, checkpoint_path, device):
    """Load model and optimizer states from a checkpoint"""
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1  # Resume from the next epoch
    print(f"Checkpoint loaded! Resuming from epoch {start_epoch}")
    return model, optimizer, start_epoch

def train_model(model, train_loader, num_epochs, device, vocab_size, checkpoint_path=None):
    from transformers.optimization import Adafactor
    import torch.nn as nn

    optimizer = Adafactor(
        model.parameters(),
        eps=(1e-30, 1e-3),
        clip_threshold=1.0,
        decay_rate=-0.8,
        beta1=None,
        weight_decay=0.01,
        scale_parameter=True,
        relative_step=True,
        warmup_init=True
    )
    
    criterion = nn.CrossEntropyLoss()
    gradient_accumulation_steps = 4
    # scaler = torch.cuda.amp.GradScaler()
    scaler = torch.amp.GradScaler('cuda')
    # Resume training if checkpoint exists
    start_epoch = 0
    if checkpoint_path:
        model, optimizer, start_epoch = load_checkpoint(model, optimizer, checkpoint_path, device)

    for epoch in range(start_epoch, num_epochs):
        model.train()
        total_loss = 0
        optimizer.zero_grad()

        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device).float(), targets.to(device)

            # with torch.cuda.amp.autocast():
            #     outputs = model(inputs).view(-1, vocab_size)
            #     targets = targets.view(-1)
            #     loss = criterion(outputs, targets) / gradient_accumulation_steps
            with torch.amp.autocast('cuda'):  # ✅ Fixed autocast
                outputs = model(inputs).view(-1, vocab_size)
                targets = targets.view(-1)
                loss = criterion(outputs, targets) / 4  # Gradient accumulation steps

            scaler.scale(loss).backward()

            if (batch_idx + 1) % gradient_accumulation_steps == 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            total_loss += loss.item() * gradient_accumulation_steps

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}, Average Loss: {avg_loss:.4f}")

        torch.save(model, 'model.pth')
        # # Save checkpoint
        # torch.save({
        #     'epoch': epoch,
        #     'model_state_dict': model.state_dict(),
        #     'optimizer_state_dict': optimizer.state_dict(),
        #     'loss': avg_loss,
        # }, 'checkpoint.pt')

    print("Training complete!")


In [10]:
# Setup and training
# d_model = 384  # MiniLM embedding size
# num_heads = 8
# hidden_dim = 1536
# num_layers = 6
# vocab_size = tokenizer.vocab_size
# max_len = 500

# d_model=1024 #len(embedding_model.encode("g"))
# num_heads=32
# hidden_dim=4096 #16384
# num_layers=24
vocab_size=tokenizer.vocab_size
max_len = 1000
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

d_model = 1024
num_heads = 32
num_layers = 24
hidden_dim = 12288 # hidden_dim = 16384


model = Transformer(d_model, num_heads, hidden_dim, num_layers, vocab_size, max_len).to(device)

In [29]:
#model = torch.load('model.pth', weights_only=False)


In [15]:
# from transformers.optimization import Adafactor

# optimizer = Adafactor(
#         model.parameters(),
#         eps=(1e-30, 1e-3),
#         clip_threshold=1.0,
#         decay_rate=-0.8,
#         beta1=None,
#         weight_decay=0.01,
#         scale_parameter=True,
#         relative_step=True,  # Use relative step sizing
#         warmup_init=True     # Enable warmup
#     )

In [16]:
# checkpoint = torch.load('checkpoint.pt')

# # Restore model and optimizer states
# model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# # Restore epoch and loss
# start_epoch = checkpoint['epoch'] + 1  # Continue from the next epoch
# loss = checkpoint['loss']

# print(f"Resuming training from epoch {start_epoch} with loss {loss:.4f}")

Resuming training from epoch 2 with loss 6.3275


In [29]:
optimize_memory()

In [13]:
train_dataset = TextDataset(df, max_len=max_len)
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)

                                                                                                                       

In [30]:
# import pickle

# # Save train_dataset to a file
# with open("H:/GAMES/train_dataset.pkl", "wb") as f:
#     pickle.dump(train_dataset, f)

# print("Dataset saved successfully!")


OSError: [Errno 28] No space left on device

In [14]:
import sys
import torch
from pympler import asizeof

# Check memory size of train_dataset
train_dataset_size = sys.getsizeof(train_dataset)  # Shallow size
train_dataset_deep_size = asizeof.asizeof(train_dataset)  # Deep size

# Check memory size of train_loader
train_loader_size = sys.getsizeof(train_loader)  # Shallow size
train_loader_deep_size = asizeof.asizeof(train_loader)  # Deep size

print(f"Shallow size of train_dataset: {train_dataset_size / (1024**2):.2f} MB")
print(f"Deep size of train_dataset: {train_dataset_deep_size / (1024**2):.2f} MB")

print(f"Shallow size of train_loader: {train_loader_size / (1024**2):.2f} MB")
print(f"Deep size of train_loader: {train_loader_deep_size / (1024**2):.2f} MB")


Shallow size of train_dataset: 0.00 MB
Deep size of train_dataset: 59.96 MB
Shallow size of train_loader: 0.00 MB
Deep size of train_loader: 59.97 MB


In [17]:
torch.save(model, 'model.pth')

In [15]:
# Create dataset and dataloader
# train_dataset = TextDataset(df, max_len=max_len)
# train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)
# Train the model
train_model(model, train_loader, num_epochs=40, device=device, vocab_size=vocab_size)

  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)


Epoch 1, Batch 0, Loss: 10.8133, Avg Loss: 10.8133
GPU Memory: 7976.2MB / 14536.0MB
Epoch 1, Batch 50, Loss: 10.7034, Avg Loss: 10.7819
GPU Memory: 7982.7MB / 15764.0MB
Epoch 1, Batch 100, Loss: 10.3118, Avg Loss: 10.6597
GPU Memory: 7983.7MB / 15764.0MB
Epoch 1, Batch 150, Loss: 9.8308, Avg Loss: 10.4764
GPU Memory: 7982.7MB / 15764.0MB
Epoch 1, Batch 200, Loss: 9.4193, Avg Loss: 10.2557
GPU Memory: 7983.7MB / 15764.0MB
Epoch 1, Batch 250, Loss: 9.0225, Avg Loss: 10.0444
GPU Memory: 7982.7MB / 15764.0MB
Epoch 1, Batch 300, Loss: 8.6639, Avg Loss: 9.8520
GPU Memory: 7983.7MB / 15764.0MB
Epoch 1, Batch 350, Loss: 8.5891, Avg Loss: 9.6850
GPU Memory: 7982.7MB / 15764.0MB
Epoch 1, Batch 400, Loss: 8.4636, Avg Loss: 9.5427
GPU Memory: 7983.7MB / 15764.0MB
Epoch 1, Batch 450, Loss: 8.3328, Avg Loss: 9.4146
GPU Memory: 7982.7MB / 15764.0MB
Epoch 1, Average Loss: 9.2995
Epoch 2, Batch 0, Loss: 8.1870, Avg Loss: 8.1870
GPU Memory: 7983.7MB / 15764.0MB
Epoch 2, Batch 50, Loss: 8.0836, Avg Loss:

KeyboardInterrupt: 

In [25]:
train_model(model, train_loader,start_epoch=50, num_epochs=40, device=device, vocab_size=vocab_size)

  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():


Epoch 1, Batch 0, Loss: 5.6266, Avg Loss: 5.6266
GPU Memory: 8631.4MB / 14306.0MB
Epoch 1, Batch 50, Loss: 5.4901, Avg Loss: 5.5854
GPU Memory: 8639.1MB / 16752.0MB
Epoch 1, Batch 100, Loss: 5.5246, Avg Loss: 5.5633
GPU Memory: 8639.1MB / 16752.0MB
Epoch 1, Batch 150, Loss: 5.4203, Avg Loss: 5.5403
GPU Memory: 8639.1MB / 16752.0MB
Epoch 1, Batch 200, Loss: 5.3766, Avg Loss: 5.5251
GPU Memory: 8639.1MB / 16752.0MB
Epoch 1, Batch 250, Loss: 5.3213, Avg Loss: 5.5017
GPU Memory: 8639.1MB / 16752.0MB
Epoch 1, Batch 300, Loss: 5.3774, Avg Loss: 5.4935
GPU Memory: 8639.1MB / 16752.0MB
Epoch 1, Batch 350, Loss: 5.5819, Avg Loss: 5.4801
GPU Memory: 8639.1MB / 16752.0MB
Epoch 1, Batch 400, Loss: 5.5981, Avg Loss: 5.4730
GPU Memory: 8639.1MB / 16752.0MB
Epoch 1, Batch 450, Loss: 5.2984, Avg Loss: 5.4638
GPU Memory: 8639.1MB / 16752.0MB
Epoch 1, Average Loss: 5.4629
Epoch 2, Batch 0, Loss: 5.4481, Avg Loss: 5.4481
GPU Memory: 8639.1MB / 16752.0MB
Epoch 2, Batch 50, Loss: 5.5456, Avg Loss: 5.4185
G

KeyboardInterrupt: 

In [27]:
model

Transformer(
  (pos_encoding): PositionalEncoding()
  (layers): ModuleList(
    (0-23): 24 x TransformerBlock(
      (attn): MultiHeadAttention(
        (qkv_proj): Linear(in_features=1024, out_features=3072, bias=True)
        (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
      )
      (ff): FeedForward(
        (fc1): Linear(in_features=1024, out_features=12288, bias=True)
        (fc2): Linear(in_features=12288, out_features=1024, bias=True)
      )
      (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    )
  )
  (output_layer): Linear(in_features=1024, out_features=32000, bias=True)
  (input_projection): Linear(in_features=1024, out_features=1024, bias=True)
)

In [36]:
train_model(model, train_loader,start_epoch=50, num_epochs=100, device=device, vocab_size=vocab_size)

  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():


Epoch 51, Batch 0, Loss: 4.7486, Avg Loss: 4.7486
Epoch 51, Batch 50, Loss: 4.7525, Avg Loss: 4.6951
Epoch 51, Batch 100, Loss: 4.6004, Avg Loss: 4.6840
Epoch 51, Batch 150, Loss: 4.4614, Avg Loss: 4.6691
Epoch 51, Batch 200, Loss: 4.7306, Avg Loss: 4.6544
Epoch 51, Batch 250, Loss: 4.5479, Avg Loss: 4.6441
Epoch 51, Batch 300, Loss: 4.4485, Avg Loss: 4.6309
Epoch 51, Batch 350, Loss: 4.5122, Avg Loss: 4.6185
Epoch 51, Batch 400, Loss: 4.5297, Avg Loss: 4.6066
Epoch 51, Batch 450, Loss: 4.5123, Avg Loss: 4.5976
Epoch 51, Average Loss: 4.5903
Epoch 52, Batch 0, Loss: 4.5649, Avg Loss: 4.5649
Epoch 52, Batch 50, Loss: 4.5381, Avg Loss: 4.5005
Epoch 52, Batch 100, Loss: 4.5162, Avg Loss: 4.4990
Epoch 52, Batch 150, Loss: 4.3415, Avg Loss: 4.4908
Epoch 52, Batch 200, Loss: 4.4661, Avg Loss: 4.4884
Epoch 52, Batch 250, Loss: 4.3692, Avg Loss: 4.4885
Epoch 52, Batch 300, Loss: 4.3686, Avg Loss: 4.4830
Epoch 52, Batch 350, Loss: 4.4119, Avg Loss: 4.4837
Epoch 52, Batch 400, Loss: 4.3464, Avg 

In [18]:
model

Transformer(
  (pos_encoding): PositionalEncoding()
  (layers): ModuleList(
    (0-79): 80 x TransformerBlock(
      (attn): MultiHeadAttention(
        (qkv_proj): Linear(in_features=384, out_features=1152, bias=True)
        (out_proj): Linear(in_features=384, out_features=384, bias=True)
      )
      (ff): FeedForward(
        (fc1): Linear(in_features=384, out_features=4000, bias=True)
        (fc2): Linear(in_features=4000, out_features=384, bias=True)
      )
      (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
    )
  )
  (output_layer): Linear(in_features=384, out_features=32000, bias=True)
  (input_projection): Linear(in_features=384, out_features=384, bias=True)
)

In [7]:
# Create dataset and dataloader
train_dataset = TextDataset(df, max_len=max_len)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)

# Train the model
train_model(model, train_loader, num_epochs=3, device=device, vocab_size=vocab_size)

  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)


Epoch 1, Batch 0, Loss: 12.0280, Avg Loss: 12.0280
GPU Memory: 2603.6MB / 3276.0MB
Epoch 1, Batch 10, Loss: 12.0386, Avg Loss: 12.1386
GPU Memory: 2610.0MB / 4282.0MB
Epoch 1, Batch 20, Loss: 11.8560, Avg Loss: 12.1141
GPU Memory: 2612.2MB / 4282.0MB
Epoch 1, Batch 30, Loss: 12.1311, Avg Loss: 12.1016
GPU Memory: 2612.2MB / 4282.0MB
Epoch 1, Batch 40, Loss: 12.0026, Avg Loss: 12.0856
GPU Memory: 2610.0MB / 4282.0MB
Epoch 1, Batch 50, Loss: 11.9014, Avg Loss: 12.0653
GPU Memory: 2610.0MB / 4282.0MB
Epoch 1, Batch 60, Loss: 11.7528, Avg Loss: 12.0317
GPU Memory: 2612.2MB / 4282.0MB
Epoch 1, Batch 70, Loss: 11.6829, Avg Loss: 11.9879
GPU Memory: 2612.2MB / 4282.0MB
Epoch 1, Batch 80, Loss: 11.7992, Avg Loss: 11.9480
GPU Memory: 2610.0MB / 4282.0MB
Epoch 1, Batch 90, Loss: 11.3347, Avg Loss: 11.9002
GPU Memory: 2610.0MB / 4282.0MB
Epoch 1, Batch 100, Loss: 11.4087, Avg Loss: 11.8494
GPU Memory: 2612.2MB / 4282.0MB
Epoch 1, Batch 110, Loss: 11.1728, Avg Loss: 11.7848
GPU Memory: 2612.2MB / 

In [None]:
import torch

# Define the checkpoint file
checkpoint_path = "checkpoint.pt"

# Load the checkpoint
checkpoint = torch.load(checkpoint_path, map_location=device)

# Load only the model weights
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # Set model to evaluation mode

print("Model loaded successfully for inference!")

# Example: Run inference
with torch.no_grad():
    input_tensor = torch.randn(1, 1024).to(device)  # Replace with actual input
    output = model(input_tensor)
    print(output)


In [34]:
texts = ["WRITE A STORY ", "STORY"]
embeddings = get_embeddings(texts).unsqueeze(0).to(device)
output = model(embeddings)
decoded_texts = model.decode(output)
print(decoded_texts)

['is']


In [19]:
import torch
from typing import List, Union

class TransformerInference:
    def __init__(self, model, tokenizer, embedding_model, device=None):
        self.model = model
        self.tokenizer = tokenizer
        self.embedding_model = embedding_model
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

    def get_embeddings(self, text: Union[str, List[str]]) -> torch.Tensor:
        """Convert input text to embeddings."""
        if isinstance(text, str):
            text = [text]
        embeddings = torch.tensor(
            self.embedding_model.encode(text), 
            dtype=torch.float32
        )
        return embeddings

    def pad_embeddings(self, embeddings: torch.Tensor, max_len: int = 100) -> torch.Tensor:
        """Pad or truncate embeddings to specified length."""
        batch_size, seq_len, emb_dim = embeddings.shape
        if seq_len < max_len:
            padding = torch.zeros(batch_size, max_len - seq_len, emb_dim, device=embeddings.device)
            return torch.cat([embeddings, padding], dim=1)
        return embeddings[:, :max_len, :]

    # @torch.no_grad()
    # def generate(
    #     self,
    #     text: Union[str, List[str]],
    #     max_length: int = 100,
    #     temperature: float = 1.0,
    #     top_k: int = 50,
    #     top_p: float = 0.9,
    # ) -> List[str]:
    #     """
    #     Generate text from input prompt.
        
    #     Args:
    #         text: Input text or list of texts
    #         max_length: Maximum length of generated sequence
    #         temperature: Sampling temperature (1.0 = normal, <1.0 = more focused, >1.0 = more random)
    #         top_k: Number of highest probability tokens to consider for sampling
    #         top_p: Cumulative probability threshold for nucleus sampling
        
    #     Returns:
    #         List of generated texts
    #     """
    #     # Prepare input
    #     embeddings = self.get_embeddings(text)
    #     embeddings = embeddings.unsqueeze(0) if len(embeddings.shape) == 2 else embeddings
    #     embeddings = embeddings.to(self.device)
    #     embeddings = self.pad_embeddings(embeddings)
        
    #     # Generate
    #     generated_ids = []
        
    #     # Forward pass through model
    #     logits = self.model(embeddings)
        
    #     # Apply temperature
    #     logits = logits / temperature
        
    #     # Apply top-k filtering
    #     if top_k > 0:
    #         indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    #         logits[indices_to_remove] = float('-inf')
        
    #     # Apply nucleus (top-p) filtering
    #     if top_p < 1.0:
    #         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    #         cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    #         sorted_indices_to_remove = cumulative_probs > top_p
    #         sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    #         sorted_indices_to_remove[..., 0] = 0
    #         indices_to_remove = sorted_indices_to_remove.scatter(dim=-1, index=sorted_indices, src=sorted_indices_to_remove)
    #         logits[indices_to_remove] = float('-inf')
        
    #     # Sample from the filtered distribution
    #     probs = torch.softmax(logits, dim=-1)
    #     next_tokens = torch.multinomial(probs.view(-1, probs.size(-1)), num_samples=1)
    #     generated_ids.append(next_tokens)
        
    #     # Decode and return results
    #     generated_ids = torch.cat(generated_ids, dim=-1)
    #     generated_texts = [
    #         self.tokenizer.decode(ids, skip_special_tokens=True)
    #         for ids in generated_ids
    #     ]
        
    #     return generated_texts


    # multiple words.
    @torch.no_grad()
    def generate(
        self,
        text: Union[str, List[str]],
        max_length: int = 100,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 0.9,
    ) -> List[str]:
        """
        Generate text from input prompt.
        """
        # Prepare input
        embeddings = self.get_embeddings(text)
        embeddings = embeddings.unsqueeze(0) if len(embeddings.shape) == 2 else embeddings
        embeddings = embeddings.to(self.device)
        embeddings = self.pad_embeddings(embeddings)
    
        generated_ids = []
    
        for _ in range(max_length):  # Iterate to generate more tokens
            logits = self.model(embeddings)[:, -1, :]  # Only take last token logits
            
            # Apply temperature
            logits = logits / temperature
    
            # Apply top-k filtering
            if top_k > 0:
                values, indices = torch.topk(logits, top_k)
                min_values = values[..., -1, None]
                logits[logits < min_values] = float('-inf')
    
            # Apply nucleus (top-p) filtering
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                logits[indices_to_remove] = float('-inf')
    
            # Sample next token
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
    
            # Append generated token
            generated_ids.append(next_token)
    
            # Convert token ID to embeddings and feed back into the model
            token_embedding = self.embedding_model.encode([self.tokenizer.decode(next_token.item())], convert_to_tensor=True)
            token_embedding = token_embedding.unsqueeze(0).to(self.device)
            embeddings = torch.cat([embeddings, token_embedding], dim=1)
    
        # Decode generated tokens
        generated_ids = torch.cat(generated_ids, dim=-1)
        generated_texts = [
            self.tokenizer.decode(ids.tolist(), skip_special_tokens=True)
            for ids in generated_ids
        ]
    
        return generated_texts

    def __call__(self, *args, **kwargs):
        return self.generate(*args, **kwargs)

# Example usage
if __name__ == "__main__":
    # Assuming model, tokenizer, and embedding_model are already defined
    inference = TransformerInference(model, tokenizer, embedding_model)
    
    # Single text generation
    prompt = "Write a short story about"
    generated_text = inference(prompt)
    print(f"Input: {prompt}")
    print(f"Generated: {generated_text[0]}")
    
    # Batch generation
    prompts = [
        "Write a poem about",
        "Explain how to",
        "Tell me about"
    ]
    generated_texts = inference(
        prompts,
        max_length=150,
        temperature=0.8,
        top_k=40,
        top_p=0.9
    )
    
    for prompt, generated in zip(prompts, generated_texts):
        print(f"\nInput: {prompt}")
        print(f"Generated: {generated}")



Input: Write a short story about
Generated: ?
 forAss?<<thAss><>istantistantthAss><OkinkthOkink I>><th howink I'ay toAss I of? need for toOk' I>Ok>< to?>> fromink me need do what downink Hmm usinginks down to that LetAss start. I start off' know step I. But
 The
 I.Ok start
?

 can  to that The
 it start be haves by

Input: Write a poem about
Generated: 
Ass?
 howAssAssAss?thAssAss>>thinkOkOkink>thAss><><Okth to soayistant howinkay
Okink I so so><Ok how
 figure> Hmms makeOk> howink Lets to I to have Hmm using start byth down' I Let Let do start is me Let Hmm what? start?. means have I, down it
 to have? off. which which is it a be
 But to down.
. can it, is? down some

. because1 can. is. it But it So to, like in? can means down
 the is' for. So it That to might.

 to'
 That


In [13]:
@torch.no_grad()
def generate(
    self,
    text: Union[str, List[str]],
    max_length: int = 100,
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 0.9,
) -> List[str]:
    """
    Generate multiple words instead of just one.
    """
    # Convert input text to token IDs
    input_ids = self.tokenizer.encode(text, return_tensors="pt").to(self.device)

    for _ in range(max_length):
        # Convert token IDs to embeddings
        embeddings = self.get_embeddings(self.tokenizer.decode(input_ids[0]))

        # Ensure correct shape
        embeddings = embeddings.unsqueeze(0).to(self.device)
        embeddings = self.pad_embeddings(embeddings)

        # Forward pass
        logits = self.model(embeddings)[:, -1, :]  # Get last token logits
        
        # Apply temperature scaling
        logits = logits / temperature

        # Apply top-k filtering
        if top_k > 0:
            values, indices = torch.topk(logits, top_k)
            min_values = values[..., -1, None]
            logits[logits < min_values] = float('-inf')

        # Apply nucleus (top-p) filtering
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
            logits[indices_to_remove] = float('-inf')

        # Sample the next token
        probs = torch.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        # Append token to input_ids
        input_ids = torch.cat([input_ids, next_token], dim=-1)

        # Stop if EOS token is generated
        if next_token.item() == self.tokenizer.eos_token_id:
            break

    # Decode final text
    generated_text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
    return [generated_text]


# Example usage
if __name__ == "__main__":
    # Assuming model, tokenizer, and embedding_model are already defined
    inference = TransformerInference(model, tokenizer, embedding_model)
    
    # Single text generation
    prompt = "Write a short story about"
    generated_text = inference(prompt)
    print(f"Input: {prompt}")
    print(f"Generated: {generated_text[0]}")
    
    # Batch generation
    prompts = [
        "Write a poem about",
        "Explain how to",
        "Tell me about"
    ]
    generated_texts = inference(
        prompts,
        max_length=150,
        temperature=0.8,
        top_k=40,
        top_p=0.9
    )
    
    for prompt, generated in zip(prompts, generated_texts):
        print(f"\nInput: {prompt}")
        print(f"Generated: {generated}")

Input: Write a short story about
Generated: Brow to that of a. to. to.. expression it is eterIll VI VerIll  isété is a Artikel. of. that.:秀 . entfer. the of ,.0 ben Brow-ά the. elements Based the elements b simultaneously children. step tov Verété, step '.util that. the.. So in bazie Ver0 a voices. is  is that. thatété in in.. brick ers. is, happ' in is of

Input: Write a poem about
Generated: voices . Zar is秀....,.. VideosIll is is, iseter is, Brow aIll.
 the на toIll на is объек. is is秀.. Ver Brow
 to matplotlib. elements...
.
.. is- step 
.- the to
 is.... simultaneously.,0 to. на. is is....0, in ..
 al. So.... it..ally and... step is console the, in is is.. is.. the..
 crown is,.. the,.
,  toally   the
 
 thatutil So.....


In [25]:
# Load your trained model and other components
inference = TransformerInference(model, tokenizer, embedding_model)

# Generate from a single prompt
text = inference("Write a story about", max_length=100, temperature=0.8)

# Or generate from multiple prompts
texts = inference([
    "Write a story about",
    "Explain how to"
], max_length=100)

In [26]:
texts

['då방深 mentions tasks another " Each tasks mentions ofpressionној an\nargv深深 developing深 mentions perhaps of mut abouttest many one mentions aboutEND easy\n phone\' music many phone music\' theној music Each about music " musicyle could Smith pending. anotherној Each could mentions mentions Each Smith He about especiallyној2ној about Eachној of music manyx\n.emplo Smith Smith especially\nblem " a aboutpow\n so about I many step0 another\' mut mentions I perhaps mentions']

In [13]:
def count_parameters(d_model=768, num_heads=12, hidden_dim=3072, num_layers=12, vocab_size=32000):
    """Calculate parameter count for each component of the transformer model"""
    
    params = {
        "Input Projection Layer": {
            "weight": d_model * d_model,
            "bias": d_model
        },
        
        "Positional Encoding": {
            "parameters": 0  # No trainable parameters
        },
        
        "Per Layer": {
            "Self-Attention": {
                "qkv_proj": {
                    "weight": d_model * (d_model * 3),
                    "bias": d_model * 3
                },
                "out_proj": {
                    "weight": d_model * d_model,
                    "bias": d_model
                }
            },
            "Feed Forward": {
                "fc1": {
                    "weight": d_model * hidden_dim,
                    "bias": hidden_dim
                },
                "fc2": {
                    "weight": hidden_dim * d_model,
                    "bias": d_model
                }
            },
            "Layer Norm (2x)": {
                "weight": d_model * 2,
                "bias": d_model * 2
            }
        },
        
        "Final Layer Norm": {
            "weight": d_model,
            "bias": d_model
        },
        
        "Output Layer": {
            "weight": d_model * vocab_size,
            "bias": vocab_size
        }
    }
    
    # Calculate total parameters per layer
    params_per_layer = (
        # Self-attention
        (d_model * (d_model * 3) + d_model * 3) +  # qkv_proj
        (d_model * d_model + d_model) +            # out_proj
        # Feed forward
        (d_model * hidden_dim + hidden_dim) +      # fc1
        (hidden_dim * d_model + d_model) +         # fc2
        # Layer norms
        (d_model * 2 + d_model * 2)               # 2 layer norms with weights and biases
    )
    
    # Calculate total parameters
    total_params = (
        # Input projection
        (d_model * d_model + d_model) +
        # All transformer layers
        (params_per_layer * num_layers) +
        # Final layer norm
        (d_model * 2) +
        # Output layer
        (d_model * vocab_size + vocab_size)
    )
    
    return {
        "Parameters per layer": params_per_layer,
        "Total parameters": total_params,
        "Parameters in millions": total_params / 1_000_000
    }

# Calculate parameters with current configuration
results = count_parameters(d_model=768, num_heads=12, hidden_dim=16384, num_layers=80, vocab_size=32000)

print(f"Parameters per transformer layer: {results['Parameters per layer']:,}")
print(f"Total parameters: {results['Total parameters']:,}")
print(f"Total parameters in millions: {results['Parameters in millions']:.2f}M")

Parameters per transformer layer: 27,548,416
Total parameters: 2,229,073,408
Total parameters in millions: 2229.07M


In [10]:
# Original parameters
original_params = {
    "d_model": 768,      # Keep this from embedding model
    "num_heads": 12,
    "hidden_dim": 3072,
    "num_layers": 12,
    "vocab_size": 32000
}

# Scaled up parameters (closer to LLaMA scale)
scaled_params = {
    "d_model": 768,              # Kept from embedding model
    "num_heads": 32,             # Increased from 12
    "hidden_dim": 16384,         # Significantly increased from 3072
    "num_layers": 80,            # Significantly increased from 12
    "vocab_size": 32000,         # LLaMA uses 32k vocab
    "max_len": 2048,             # Increased context length
    "dropout_rate": 0.1
}

def calculate_parameter_count(params):
    d_model = params["d_model"]
    hidden_dim = params["hidden_dim"]
    num_layers = params["num_layers"]
    vocab_size = params["vocab_size"]
    
    # Parameters per layer
    attention_params = (d_model * (d_model * 3) + d_model * 3) + (d_model * d_model + d_model)
    ffn_params = (d_model * hidden_dim + hidden_dim) + (hidden_dim * d_model + d_model)
    layer_norm_params = 4 * d_model  # 2 layer norms per layer
    
    params_per_layer = attention_params + ffn_params + layer_norm_params
    
    # Total parameters
    total_params = (
        (d_model * d_model + d_model) +  # Input projection
        (params_per_layer * num_layers) + # All transformer layers
        (d_model * 2) +                   # Final layer norm
        (d_model * vocab_size + vocab_size) # Output layer
    )
    
    return total_params / 1_000_000  # Return in millions

print(f"Original model parameters: {calculate_parameter_count(original_params):.2f}M")
print(f"Scaled model parameters: {calculate_parameter_count(scaled_params):.2f}M")

# Modified training setup for the scaled model
def get_scaled_training_config():
    return {
        "batch_size": 8,                    # Reduced due to model size
        "gradient_accumulation_steps": 32,   # Increased for effective larger batch
        "learning_rate": 1e-4,
        "warmup_steps": 2000,
        "weight_decay": 0.1,
        "max_grad_norm": 1.0,
        "num_epochs": 3,
        "optimizer": "AdaFactor",           # Changed from AdamW for memory efficiency
        "lr_scheduler": "cosine_with_warmup"
    }

# Example training configuration with optimizations for large scale
training_config = """
# Training setup for scaled model
optimizer = transformers.Adafactor(
    model.parameters(),
    lr=1e-4,
    eps=(1e-30, 1e-3),
    clip_threshold=1.0,
    decay_rate=-0.8,
    beta1=None,
    scale_parameter=True,
    relative_step=True,
    warmup_init=True
)

# Gradient accumulation setup
gradient_accumulation_steps = 32
model.gradient_checkpointing_enable()  # Enable gradient checkpointing

# Use mixed precision training
scaler = torch.cuda.amp.GradScaler()

# Training loop with optimizations
for epoch in range(num_epochs):
    for i, batch in enumerate(train_loader):
        with torch.cuda.amp.autocast():
            loss = model(batch) / gradient_accumulation_steps
            scaler.scale(loss).backward()
            
        if (i + 1) % gradient_accumulation_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
"""

Original model parameters: 110.25M
Scaled model parameters: 2229.07M


In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
import numpy as np

# Load LLaMA tokenizer and MiniLM embeddings
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

def get_embeddings(texts):
    return torch.tensor(embedding_model.encode(texts), dtype=torch.float32)

def decode_nearest_embedding(embeddings):
    all_tokens = tokenizer.convert_ids_to_tokens(list(range(tokenizer.vocab_size)))
    all_embeddings = get_embeddings(all_tokens)
    similarities = torch.cdist(embeddings, all_embeddings)
    closest_indices = similarities.argmin(dim=-1)
    return [tokenizer.decode(idx) for idx in closest_indices]

# Positional Encoding
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.pe = pe.unsqueeze(0)
    
    def forward(self, x):
        return x + self.pe[:, :x.size(1)].to(x.device)

# Multi-Head Self-Attention
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.qkv_proj = nn.Linear(d_model, d_model * 3)
        self.out_proj = nn.Linear(d_model, d_model)
    
    def forward(self, x):
        B, T, C = x.shape
        qkv = self.qkv_proj(x).reshape(B, T, 3, self.num_heads, self.d_k).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn_weights = (q @ k.transpose(-2, -1)) / (self.d_k ** 0.5)
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_output = (attn_weights @ v).transpose(1, 2).reshape(B, T, C)
        return self.out_proj(attn_output)

# Feed-Forward Network
class FeedForward(nn.Module):
    def __init__(self, d_model, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(d_model, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, d_model)
    
    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))

# Transformer Block
class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, hidden_dim):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, hidden_dim)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
    
    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ff(self.norm2(x))
        return x

# Full Transformer Model
class Transformer(nn.Module):
    def __init__(self, d_model, num_heads, hidden_dim, num_layers, max_len=5000):
        super().__init__()
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList([TransformerBlock(d_model, num_heads, hidden_dim) for _ in range(num_layers)])
        self.output_layer = nn.Linear(d_model, d_model)
    
    def forward(self, x):
        x = self.pos_encoding(x)
        for layer in self.layers:
            x = layer(x)
        return self.output_layer(x)

# Example usage
d_model = 384  # MiniLM embedding size
num_heads = 8
hidden_dim = 1536
num_layers = 6
max_len = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Transformer(d_model, num_heads, hidden_dim, num_layers, max_len).to(device)
texts = ["Hello, world!", "How are you?"]
embeddings = get_embeddings(texts).unsqueeze(0).to(device)
output = model(embeddings)
decoded_text = decode_nearest_embedding(output.squeeze(0))
print("Decoded output:", decoded_text)


RuntimeError: X1 and X2 must have the same device type. X1: cuda X2: cpu