In [1]:
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118


Looking in indexes: https://download.pytorch.org/whl/cu118
INFO: pip is looking at multiple versions of torch to determine which version is compatible with other requirements. This could take a while.
Collecting torch
  Downloading https://download.pytorch.org/whl/cu118/torch-2.7.1%2Bcu118-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (28 kB)
Collecting sympy>=1.13.3 (from torch)
  Downloading https://download.pytorch.org/whl/sympy-1.13.3-py3-none-any.whl.metadata (12 kB)
Collecting nvidia-cuda-nvrtc-cu11==11.8.89 (from torch)
  Downloading https://download.pytorch.org/whl/cu118/nvidia_cuda_nvrtc_cu11-11.8.89-py3-none-manylinux1_x86_64.whl (23.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.2/23.2 MB[0m [31m72.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-runtime-cu11==11.8.89 (from torch)
  Downloading https://download.pytorch.org/whl/cu118/nvidia_cuda_runtime_cu11-11.8.89-py3-none-manylinux1_x86_64.whl (875 kB)
[2K     [90m━━━━━

In [2]:
import torch
assert torch.cuda.is_available(), "GPU not detected"


In [3]:
import torch.nn as nn
import torch.nn.functional as F

class SimpleExpert(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)


In [4]:
class TopKRouter(nn.Module):
    def __init__(self, dim, num_experts, k=8):
        super().__init__()
        self.k = k
        self.gate = nn.Linear(dim, num_experts)
    def forward(self, x):
        # x: (B, T, D)
        logits = self.gate(x)                           # (B, T, E)
        topk_vals, topk_idx = torch.topk(logits, self.k, dim=-1)
        mask = torch.full_like(logits, float("-inf"))
        sparse = mask.scatter(-1, topk_idx, topk_vals)
        weights = F.softmax(sparse, dim=-1)             # (B, T, E)
        return weights, topk_idx


In [5]:
class MoELayer(nn.Module):
    def __init__(self, dim, hidden_dim, num_experts, k):
        super().__init__()
        self.router  = TopKRouter(dim, num_experts, k)
        self.experts = nn.ModuleList(
            SimpleExpert(dim, hidden_dim) for _ in range(num_experts)
        )
    def forward(self, x):
        B, T, D = x.shape
        w, idx = self.router(x)                 # w: (B,T,E), idx: (B,T,k)
        out = torch.zeros_like(x)
        flat_x = x.view(-1, D)                  # (B·T, D)
        flat_w = w.view(-1, w.size(-1))         # (B·T, E)

        for e, expert in enumerate(self.experts):
            mask = (idx == e).any(-1).view(-1)  # tokens routed to expert e
            if mask.sum() == 0: 
                continue
            inp = flat_x[mask]                          # select tokens
            res = expert(inp)                           # expert output
            w_e = flat_w[mask, e].unsqueeze(-1)         # weights
            out.view(-1, D)[mask] += res * w_e          # weighted sum

        return out


In [6]:
class MoETransformerBlock(nn.Module):
    def __init__(self, dim, heads, num_experts, k):
        super().__init__()
        self.ln1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ln2 = nn.LayerNorm(dim)
        self.moe = MoELayer(dim, dim*4, num_experts, k)

    def forward(self, x):
        # Self-Attention
        resid = x
        x = self.ln1(x)
        attn_out, _ = self.attn(x, x, x)
        x = resid + attn_out

        # MoE instead of FFN
        resid = x
        x = self.ln2(x)
        x = resid + self.moe(x)
        return x


In [7]:
class MoELM(nn.Module):
    def __init__(self, vocab_size, dim, depth, heads, num_experts, k, max_len=2048):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, dim)
        self.pos_emb   = nn.Embedding(max_len, dim)
        self.layers    = nn.ModuleList(
            MoETransformerBlock(dim, heads, num_experts, k)
            for _ in range(depth)
        )
        self.ln_f      = nn.LayerNorm(dim)
        self.head      = nn.Linear(dim, vocab_size)

    def forward(self, ids, targets=None):
        B, T = ids.shape
        pos = torch.arange(T, device=ids.device).unsqueeze(0)
        x = self.token_emb(ids) + self.pos_emb(pos)
        for block in self.layers:
            x = block(x)
        x = self.ln_f(x)
        logits = self.head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1)
            )
        return logits, loss

    @torch.no_grad()
    def generate(self, ids, max_new_tokens, temp=1.0):
        for _ in range(max_new_tokens):
            logits, _ = self(ids)
            probs = F.softmax(logits[:, -1, :]/temp, dim=-1)
            next_id = torch.multinomial(probs, 1)
            ids = torch.cat([ids, next_id], dim=1)
        return ids


In [8]:
model = MoELM(
    vocab_size=50000, dim=256,
    depth=6, heads=4,
    num_experts=16, k=2,
    max_len=128
).cuda()


In [9]:
import torch

batch_size = 1
seq_len = 32  # Try something ≤ max_len to avoid OOM
vocab_size = 50000

# Random dummy input for demonstration:
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device='cuda')
targets = torch.randint(0, vocab_size, (batch_size, seq_len), device='cuda')


In [10]:
logits, loss = model(input_ids, targets)
print("Logits shape:", logits.shape)  # Should be [1, 32, 50000]
print("Loss:", loss.item())


Logits shape: torch.Size([1, 32, 50000])
Loss: 11.04063892364502


In [11]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model.train()
for step in range(100):  # Replace with your number of steps
    logits, loss = model(input_ids, targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(f"Step {step}, Loss: {loss.item():.4f}")


Step 0, Loss: 11.0362
Step 1, Loss: 6.5686
Step 2, Loss: 4.1262
Step 3, Loss: 2.7558
Step 4, Loss: 2.0693
Step 5, Loss: 1.7070
Step 6, Loss: 1.4837
Step 7, Loss: 1.2988
Step 8, Loss: 1.1431
Step 9, Loss: 0.9861
Step 10, Loss: 0.8338
Step 11, Loss: 0.6852
Step 12, Loss: 0.5614
Step 13, Loss: 0.4500
Step 14, Loss: 0.3599
Step 15, Loss: 0.2881
Step 16, Loss: 0.2301
Step 17, Loss: 0.1850
Step 18, Loss: 0.1482
Step 19, Loss: 0.1205
Step 20, Loss: 0.0971
Step 21, Loss: 0.0798
Step 22, Loss: 0.0660
Step 23, Loss: 0.0553
Step 24, Loss: 0.0464
Step 25, Loss: 0.0400
Step 26, Loss: 0.0346
Step 27, Loss: 0.0299
Step 28, Loss: 0.0261
Step 29, Loss: 0.0230
Step 30, Loss: 0.0204
Step 31, Loss: 0.0184
Step 32, Loss: 0.0166
Step 33, Loss: 0.0151
Step 34, Loss: 0.0139
Step 35, Loss: 0.0128
Step 36, Loss: 0.0118
Step 37, Loss: 0.0110
Step 38, Loss: 0.0104
Step 39, Loss: 0.0098
Step 40, Loss: 0.0093
Step 41, Loss: 0.0088
Step 42, Loss: 0.0085
Step 43, Loss: 0.0079
Step 44, Loss: 0.0076
Step 45, Loss: 0.00

In [12]:
prompt = torch.randint(0, 50000, (1, 16), device='cuda')
out = model.generate(prompt, max_new_tokens=16)
print(out)


tensor([[37258, 12709, 17479, 49563, 46406, 21212,  9187, 11848, 13444, 13654,
         44498, 12141, 41223, 32827, 45000, 16927, 10161, 48738, 21188, 15608,
         14018,  2678, 21677, 14752, 37555,  2301,  5504, 14917, 20963,  5188,
         14811, 40304]], device='cuda:0')


In [13]:
pip install transformers


Note: you may need to restart the kernel to use updated packages.


In [14]:
from transformers import GPT2Tokenizer

# Load a pre-trained tokenizer (vocab_size ~50,000)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Decode your tensor
tensor = torch.tensor([[25654, 24315, 18540, 21630, 37387, 18774, 43277, 28239, 23174, 48593,
                        41090, 42431, 14597, 41380, 24736, 35003, 10681, 18598, 35566, 38812,
                        14147,  5208,  1674, 26094, 46819,  3678,  1674, 19421, 16492, 23932,
                        34899,  9181]], device='cuda:0')

text = tokenizer.decode(tensor[0].cpu().tolist())  # Move to CPU for decoding
print(text)  # Will produce text based on GPT-2's vocab (may look like gibberish if IDs don't match real words)


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

 Bard emailed multipickers Caldwell recruited LINStudents AndreaCalling Nicotineuskydalenatureconservancy cush drib unemployment descenthelial expulsion tablet appre foreBIT Frenchmanitute fore │ Pand toes INV cited


In [15]:
pip install datasets transformers


Collecting fsspec<=2025.3.0,>=2023.1.0 (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets)
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Downloading fsspec-2025.3.0-py3-none-any.whl (193 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m193.6/193.6 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fsspec
  Attempting uninstall: fsspec
    Found existing installation: fsspec 2025.5.1
    Uninstalling fsspec-2025.5.1:
      Successfully uninstalled fsspec-2025.5.1
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
bigframes 2.8.0 requires google-cloud-bigquery-storage<3.0.0,>=2.30.0, which is not installed.
cesium 0.12.4 requires numpy<3.0,>=2.0, but you have numpy 1.26.4 which is incompatible.
gcsfs 2025.3.2 requires fsspec==2025.3.2, but you have fsspec 2025.3.0 which is incompatibl

In [16]:
import torch
print(torch.cuda.is_available())  # Should be True
print(torch.version.cuda)  # Should be 11.8 for cu118
print(torch.cuda.get_device_properties(0))  # Check T4 details
torch.cuda.empty_cache()  # Test alone
test_tensor = torch.rand(1, device='cuda')  # Simple allocation
print("Test passed if no error.")


True
11.8
_CudaDeviceProperties(name='Tesla T4', major=7, minor=5, total_memory=15095MB, multi_processor_count=40, uuid=593c882b-50aa-b01e-6f87-fb04087a0a34, L2_cache_size=4MB)
Test passed if no error.


In [17]:
# Updated Code for Stable MoE Model Training

# Imports and Setup
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import GPT2Tokenizer
from torch import optim
from torch.cuda.amp import autocast, GradScaler

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["TORCH_USE_CUDA_DSA"] = "1"
torch.backends.cuda.enable_mem_efficient_sdp(False)

torch.cuda.empty_cache()
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Weight Initialization Function
def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight, gain=nn.init.calculate_gain('relu'))
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Embedding):
        nn.init.normal_(m.weight, mean=0, std=0.01)  # Smaller std for stability

# Model Classes with Stability Fixes
class SimpleExpert(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
        self.apply(init_weights)

    def forward(self, x):
        return self.net(x)

class TopKRouter(nn.Module):
    def __init__(self, dim, num_experts, k=2):
        super().__init__()
        self.k = k
        self.gate = nn.Linear(dim, num_experts)
        self.apply(init_weights)

    def forward(self, x):
        logits = self.gate(x)
        logits = torch.clamp(logits, min=-10, max=10)  # Clip to prevent extremes
        topk_vals, topk_idx = torch.topk(logits, self.k, dim=-1)
        mask = torch.full_like(logits, float("-inf"))
        sparse = mask.scatter(-1, topk_idx, topk_vals)
        weights = F.softmax(sparse, dim=-1) + 1e-8  # Epsilon for stability
        return weights, topk_idx

class MoELayer(nn.Module):
    def __init__(self, dim, hidden_dim, num_experts, k):
        super().__init__()
        self.router = TopKRouter(dim, num_experts, k)
        self.experts = nn.ModuleList(SimpleExpert(dim, hidden_dim) for _ in range(num_experts))
        self.apply(init_weights)

    def forward(self, x):
        B, T, D = x.shape
        w, idx = self.router(x)
        out = torch.zeros_like(x)
        flat_x = x.view(-1, D)
        flat_w = w.view(-1, w.size(-1))
        for e, expert in enumerate(self.experts):
            mask = (idx == e).any(-1).view(-1)
            if mask.sum() == 0:
                continue
            inp = flat_x[mask]
            res = expert(inp)
            w_e = flat_w[mask, e].unsqueeze(-1)
            out.view(-1, D)[mask] += res * w_e
        return out

class MoETransformerBlock(nn.Module):
    def __init__(self, dim, heads, num_experts, k):
        super().__init__()
        self.ln1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ln2 = nn.LayerNorm(dim)
        self.moe = MoELayer(dim, dim*4, num_experts, k)
        self.apply(init_weights)

    def forward(self, x, attn_mask=None):
        resid = x
        x = self.ln1(x)
        key_padding_mask = (attn_mask == 0) if attn_mask is not None else None
        attn_out, _ = self.attn(x, x, x, key_padding_mask=key_padding_mask)
        x = resid + attn_out
        resid = x
        x = self.ln2(x)
        x = resid + self.moe(x)
        return x

class MoELM(nn.Module):
    def __init__(self, vocab_size, dim, depth, heads, num_experts, k, max_len=512):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, dim)
        self.pos_emb = nn.Embedding(max_len, dim)
        self.layers = nn.ModuleList(MoETransformerBlock(dim, heads, num_experts, k) for _ in range(depth))
        self.ln_f = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, vocab_size)
        self.apply(init_weights)

    def forward(self, ids, attention_mask=None, targets=None):
        B, T = ids.shape
        pos = torch.arange(0, T, dtype=torch.long, device=ids.device).unsqueeze(0).expand(B, T)
        x = self.token_emb(ids) + self.pos_emb(pos)
        for block in self.layers:
            x = block(x, attention_mask)
        x = self.ln_f(x)
        logits = self.head(x)
        if torch.isnan(logits).any() or torch.isinf(logits).any():
            print("Warning: NaN/Inf detected in logits!")
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=tokenizer.pad_token_id)
        return logits, loss

    @torch.no_grad()
    def generate(self, ids, attention_mask=None, max_new_tokens=50, temp=1.0):
        current_mask = attention_mask
        max_pos = self.pos_emb.num_embeddings
        for _ in range(max_new_tokens):
            if ids.shape[1] > max_pos:
                ids = ids[:, -max_pos:]
                if current_mask is not None:
                    current_mask = current_mask[:, -max_pos:]
            logits, _ = self(ids, current_mask)
            logits = torch.clamp(logits, min=-100, max=100)  # Clip to prevent NaN in softmax
            probs = F.softmax(logits[:, -1, :]/temp, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            ids = torch.cat([ids, next_id], dim=1)
            if current_mask is not None:
                current_mask = torch.cat([current_mask, torch.ones((current_mask.shape[0], 1), device=ids.device)], dim=1)
        return ids

# Load Tokenizer and Dataset
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:5%]")  # Using 5% as before

# Pre-filter empty text examples
def filter_empty(examples):
    return len(examples["text"].strip()) > 0

dataset = dataset.filter(filter_empty, batched=False)  # Non-batched filter for simplicity

# Tokenize without filtering in the function
def tokenize_function(examples):
    return tokenizer(examples["text"], truncation=True, max_length=128, padding="max_length")

tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])

# Post-tokenization filter for non-empty input_ids
def filter_tokenized(example):
    return len(example["input_ids"]) > 1 and sum(example["attention_mask"]) > 1  # Ensure not all padded

tokenized_dataset = tokenized_dataset.filter(filter_tokenized)

tokenized_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])

dataloader = DataLoader(tokenized_dataset, batch_size=1, shuffle=True)
print(f"Filtered dataset size: {len(tokenized_dataset)} examples")


# Instantiate Model
actual_vocab_size = len(tokenizer)
model = MoELM(
    vocab_size=actual_vocab_size,
    dim=128,
    depth=2,
    heads=2,
    num_experts=4,
    k=2,
    max_len=512
).to(device)

# Mixed Precision and Optimizer with Weight Decay
scaler = GradScaler()
optimizer = optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.01)  # Added decay

# Training Loop with NaN Handling
model.train()
num_epochs = 5  # Increased for better results
for epoch in range(num_epochs):
    total_loss = 0
    num_batches = 0
    skipped = 0
    for batch in dataloader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        targets = input_ids[:, 1:].contiguous()
        input_ids = input_ids[:, :-1].contiguous()
        attention_mask = attention_mask[:, :-1].contiguous()
        
        with autocast():
            logits, loss = model(input_ids, attention_mask, targets)
        if loss is not None and not torch.isnan(loss):
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            total_loss += loss.item()
            num_batches += 1
        else:
            skipped += 1
            # Add small dummy loss for minimal update
            dummy_loss = torch.tensor(10.0, requires_grad=True).to(device)
            dummy_loss.backward()
            optimizer.step()
            optimizer.zero_grad()
    if num_batches > 0:
        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch + 1}/{num_epochs}, Average Loss: {avg_loss:.4f}, Skipped Batches: {skipped}")
    else:
        print(f"Epoch {epoch + 1}/{num_epochs}, All batches skipped - check data or initialization.")

# Generate Example
model.eval()
prompt_text = "Once upon a time"
prompt_ids = tokenizer.encode(prompt_text, return_tensors="pt").to(device)
prompt_mask = torch.ones_like(prompt_ids).to(device)

generated_ids = model.generate(prompt_ids, attention_mask=prompt_mask, max_new_tokens=50)
decoded_text = tokenizer.decode(generated_ids[0].cpu().tolist())
print("Generated Text:", decoded_text)


README.md: 0.00B [00:00, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/733k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/6.36M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/657k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1836 [00:00<?, ? examples/s]

Map:   0%|          | 0/1185 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1185 [00:00<?, ? examples/s]

Filtered dataset size: 1185 examples


  scaler = GradScaler()
  with autocast():


Epoch 1/5, Average Loss: 10.2089, Skipped Batches: 0
Epoch 2/5, Average Loss: 8.8860, Skipped Batches: 0
Epoch 3/5, Average Loss: 7.9837, Skipped Batches: 0
Epoch 4/5, Average Loss: 7.3568, Skipped Batches: 0
Epoch 5/5, Average Loss: 6.9697, Skipped Batches: 0
Generated Text: Once upon a time  generation seemed who round entertainment , in , Mus , Fred  conducted the the deice Kw the was ofly and to horizontally quick andation the systematic specified inner Oro Looking credit in in action way carrying and Mitchellane with@ notes cartridge Christina are
