In [1]:
import csv
import json
import math

import torch
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

In [2]:
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
def load_model(model_name):
    """
    Load a causal language model and its tokenizer from one checkpoint name.

    Args:
        model_name: Identifier handed unchanged to both ``from_pretrained`` calls.

    Returns:
        A ``(model, tokenizer)`` tuple, in that order.
    """
    # Weights are requested in float16, and the checkpoint is allowed to ship
    # its own modelling code (trust_remote_code).
    # Flash attention was left switched off by the author
    # (the `use_flash_attention_2=True` option is intentionally not passed).
    model_kwargs = {
        "torch_dtype": torch.float16,
        "trust_remote_code": True,
    }
    loaded_model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    return loaded_model, loaded_tokenizer

# Fetch the DeepSeek MoE base checkpoint together with its tokenizer.
model, tokenizer = load_model("deepseek-ai/deepseek-moe-16b-base")

# Put the module in evaluation mode and move it to the selected device.
# `eval()` returns the module itself, so the chained call remains the cell's
# final expression and the model repr is still displayed below the cell.
model.eval().to(device)

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

DeepseekForCausalLM(
  (model): DeepseekModel(
    (embed_tokens): Embedding(102400, 2048)
    (layers): ModuleList(
      (0): DeepseekDecoderLayer(
        (self_attn): DeepseekSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): DeepseekRotaryEmbedding()
        )
        (mlp): DeepseekMLP(
          (gate_proj): Linear(in_features=2048, out_features=10944, bias=False)
          (up_proj): Linear(in_features=2048, out_features=10944, bias=False)
          (down_proj): Linear(in_features=10944, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): DeepseekRMSNorm()
        (post_attention_layernorm): DeepseekRMSNorm()
      )
      (1-27): 27 x DeepseekDecod

In [34]:
def calculate_perplexity(model, tokenizer, json_file_path, device=device, domain="code", top_k=2):
    """
    Compute per-sample perplexity with a configurable number of experts per token.

    Every layer in ``model.model.layers`` whose ``mlp`` exposes an ``experts``
    attribute is reconfigured to use ``top_k`` experts per token (and, when the
    ``mlp`` has a ``gate``, the gate's ``top_k`` is set as well). This setting is
    NOT restored afterwards: the model keeps the last ``top_k`` it was given.

    Each non-blank sample is scored on its own (one forward pass per sample, no
    batching or truncation). Results are returned and also written to
    ``{domain}_perplexity_top{top_k}.csv`` in the current working directory; an
    existing file with that name is overwritten.

    Args:
        model: The DeepSeek model; called as ``model(input_ids)`` and expected to
            return an object with a ``logits`` attribute.
        tokenizer: The tokenizer to use; called with ``return_tensors='pt'``.
        json_file_path: Path to a UTF-8 JSON file whose top-level object maps a
            domain name to a list of string samples.
        device: The device the token ids are moved to (cuda/cpu). The model is
            assumed to already live on that device.
        domain: Key of the samples inside the JSON file; also used for the output
            file name and the first CSV column header ("code", "text", etc).
        top_k: Number of experts to select per token (1-6).

    Returns:
        A list of ``(sample_number, perplexity)`` tuples. ``sample_number`` is the
        1-based position of the sample in the JSON list, so numbers of skipped
        samples are simply absent.

    Raises:
        ValueError: If ``top_k`` is outside the range 1..6.
    """
    # Validate top_k parameter
    if not 1 <= top_k <= 6:
        raise ValueError("top_k must be between 1 and 6")

    # Configure MoE layers to use the specified number of experts
    for layer in model.model.layers:
        if hasattr(layer.mlp, 'experts'):  # Only MoE layers carry experts
            layer.mlp.num_experts_per_tok = top_k
            if hasattr(layer.mlp, 'gate'):
                layer.mlp.gate.top_k = top_k

    # Read the samples for the requested domain
    with open(json_file_path, 'r', encoding='utf-8') as f:
        samples = json.load(f)[domain]

    # Per-token losses are needed, so no reduction; the module is stateless and
    # can be shared by every sample.
    loss_fct = CrossEntropyLoss(reduction='none')

    perplexities = []
    progress_bar = tqdm(samples, desc=f"Processing {domain} samples (top-k={top_k})")
    file_name = f"{domain}_perplexity_top{top_k}.csv"

    with open(file_name, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        # Mode 'w' always starts from an empty file, so the header is written
        # unconditionally.
        writer.writerow([f'{domain}_num', 'perplexity'])

        for i, sample in enumerate(progress_bar):
            text = sample.strip()
            if not text:
                continue

            input_ids = tokenizer(text, return_tensors='pt').input_ids.to(device)

            # Next-token prediction needs at least one (context, target) pair.
            # With fewer than two tokens the loss tensor is empty and its mean
            # is NaN, which would otherwise end up in the CSV.
            if input_ids.size(-1) < 2:
                continue

            with torch.no_grad():
                outputs = model(input_ids)
                logits = outputs.logits

                # Clean up CUDA memory
                del outputs
                torch.cuda.empty_cache()

                # Position t predicts token t+1: drop the last logit row and the
                # first target token. `.float()` is a no-op for float32 logits
                # and avoids half-precision rounding of the per-token losses if
                # a model returns float16 logits.
                shift_logits = logits[..., :-1, :].float().contiguous()
                shift_target_ids = input_ids[..., 1:].contiguous()
                del logits

                loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                                shift_target_ids.view(-1)).cpu()
                del shift_logits, shift_target_ids

                # Perplexity = exp(mean negative log-likelihood per token)
                ppl = torch.exp(loss.mean()).item()
                del loss

            progress_bar.set_postfix({'Perplexity': f'{ppl:.2f}'})
            perplexities.append((i + 1, ppl))
            writer.writerow([i + 1, ppl])

    return perplexities

In [40]:
text_file_path = "interp-data/code.json"

# Sweep the number of active experts from 6 down to 1 over the "code" samples
# and print the full list of (sample number, perplexity) pairs for each setting.
for k in reversed(range(1, 7)):
    ppl = calculate_perplexity(
        model=model,
        tokenizer=tokenizer,
        json_file_path=text_file_path,
        domain="code",
        top_k=k,
    )
    print(f"Perplexity (top-{k}): {ppl}")


Processing code samples (top-k=6): 100%|██████████| 200/200 [05:02<00:00,  1.51s/it, Perplexity=1.59] 


Perplexity (top-6): [(1, 3.6393380165100098), (2, 2.95393443107605), (3, 1.7198283672332764), (4, 1.475581169128418), (5, 2.2476859092712402), (6, 3.5342817306518555), (7, 2.2424824237823486), (8, 5.658669471740723), (9, 1.3869194984436035), (10, 1.2824623584747314), (11, 2.2240259647369385), (12, 4.973986625671387), (13, 2.373444080352783), (14, 1.9076812267303467), (15, 6.475954055786133), (16, 7.0264811515808105), (17, 3.963676929473877), (18, 3.849867343902588), (19, 1.5237327814102173), (20, 2.1962056159973145), (21, 2.310671091079712), (22, 2.1384129524230957), (23, 2.4464073181152344), (24, 1.8635361194610596), (25, 4.023367881774902), (26, 1.5337620973587036), (27, 6.276959419250488), (28, 1.991531252861023), (29, 1.791770100593567), (30, 1.5346059799194336), (31, 5.477392196655273), (32, 2.0099716186523438), (33, 1.975020408630371), (34, 3.181039333343506), (35, 2.728817939758301), (36, 2.6321280002593994), (37, 2.389310121536255), (38, 3.086223602294922), (39, 2.5341613292694

Processing code samples (top-k=5): 100%|██████████| 200/200 [04:34<00:00,  1.37s/it, Perplexity=1.59] 


Perplexity (top-5): [(1, 3.64931058883667), (2, 3.0740864276885986), (3, 1.7237818241119385), (4, 1.4943830966949463), (5, 2.2486321926116943), (6, 3.4931201934814453), (7, 2.302518844604492), (8, 5.662024021148682), (9, 1.383519172668457), (10, 1.2793495655059814), (11, 2.234905481338501), (12, 5.021279335021973), (13, 2.346092462539673), (14, 1.92448091506958), (15, 6.504097938537598), (16, 7.132164478302002), (17, 3.991126775741577), (18, 3.862718343734741), (19, 1.5294129848480225), (20, 2.2382235527038574), (21, 2.3135645389556885), (22, 2.2159619331359863), (23, 2.4689905643463135), (24, 1.8904926776885986), (25, 4.044453144073486), (26, 1.533766269683838), (27, 6.421784400939941), (28, 1.9970197677612305), (29, 1.7893403768539429), (30, 1.5343058109283447), (31, 5.553557395935059), (32, 2.0042362213134766), (33, 1.9945217370986938), (34, 3.148193836212158), (35, 2.692101001739502), (36, 2.628990650177002), (37, 2.3956258296966553), (38, 3.158074140548706), (39, 2.544734716415405

Processing code samples (top-k=4): 100%|██████████| 200/200 [04:01<00:00,  1.21s/it, Perplexity=1.59] 


Perplexity (top-4): [(1, 3.706648826599121), (2, 3.1796324253082275), (3, 1.718397617340088), (4, 1.5084750652313232), (5, 2.26690936088562), (6, 3.4934091567993164), (7, 2.353666305541992), (8, 5.754793167114258), (9, 1.3809739351272583), (10, 1.2755281925201416), (11, 2.2242302894592285), (12, 4.9840617179870605), (13, 2.3333959579467773), (14, 1.9360296726226807), (15, 6.611832141876221), (16, 7.389785289764404), (17, 4.009702682495117), (18, 3.9211108684539795), (19, 1.5403488874435425), (20, 2.312075614929199), (21, 2.3011422157287598), (22, 2.265432834625244), (23, 2.4856433868408203), (24, 1.907849907875061), (25, 3.971853017807007), (26, 1.5450918674468994), (27, 6.669025897979736), (28, 1.987296462059021), (29, 1.8177576065063477), (30, 1.5275741815567017), (31, 5.730889797210693), (32, 2.016425848007202), (33, 2.013301134109497), (34, 3.2982988357543945), (35, 2.666780948638916), (36, 2.597932815551758), (37, 2.35581374168396), (38, 3.1769914627075195), (39, 2.570407390594482

Processing code samples (top-k=3): 100%|██████████| 200/200 [03:32<00:00,  1.06s/it, Perplexity=1.60] 


Perplexity (top-3): [(1, 3.748962879180908), (2, 3.460726261138916), (3, 1.742380976676941), (4, 1.5181143283843994), (5, 2.223583698272705), (6, 3.478574275970459), (7, 2.380861520767212), (8, 6.079816818237305), (9, 1.39162015914917), (10, 1.265258550643921), (11, 2.2579219341278076), (12, 5.001200199127197), (13, 2.3857338428497314), (14, 1.9240682125091553), (15, 6.816288471221924), (16, 8.025300979614258), (17, 4.0658674240112305), (18, 4.085056781768799), (19, 1.546324372291565), (20, 2.4120638370513916), (21, 2.297670841217041), (22, 2.2391276359558105), (23, 2.4699623584747314), (24, 1.9367371797561646), (25, 3.9714553356170654), (26, 1.549174427986145), (27, 6.967252731323242), (28, 2.0143795013427734), (29, 1.8267953395843506), (30, 1.5491328239440918), (31, 5.859408378601074), (32, 2.0506625175476074), (33, 2.134036064147949), (34, 3.37134051322937), (35, 2.6940388679504395), (36, 2.588442802429199), (37, 2.384453296661377), (38, 3.16365909576416), (39, 2.622987985610962), (

Processing code samples (top-k=2): 100%|██████████| 200/200 [07:02<00:00,  2.11s/it, Perplexity=1.66] 


Perplexity (top-2): [(1, 3.8662467002868652), (2, 4.634333610534668), (3, 1.7419947385787964), (4, 1.5546294450759888), (5, 2.4085428714752197), (6, 3.6045382022857666), (7, 2.4202709197998047), (8, 7.094555854797363), (9, 1.418639898300171), (10, 1.2889816761016846), (11, 2.3331704139709473), (12, 5.432575225830078), (13, 2.4318220615386963), (14, 2.0694332122802734), (15, 7.479215145111084), (16, 10.885858535766602), (17, 4.6522345542907715), (18, 4.251559257507324), (19, 1.5735458135604858), (20, 2.9462242126464844), (21, 2.3267951011657715), (22, 2.657407522201538), (23, 2.536320924758911), (24, 2.013063907623291), (25, 4.490602016448975), (26, 1.5586919784545898), (27, 7.826318740844727), (28, 2.0726373195648193), (29, 1.912284016609192), (30, 1.5866847038269043), (31, 6.173233985900879), (32, 2.130819797515869), (33, 2.526991844177246), (34, 3.684293746948242), (35, 3.1529016494750977), (36, 2.664602279663086), (37, 2.4257185459136963), (38, 3.277728319168091), (39, 2.94692325592

Processing code samples (top-k=1): 100%|██████████| 200/200 [03:55<00:00,  1.18s/it, Perplexity=2.08] 

Perplexity (top-1): [(1, 7.629128456115723), (2, 6.170047760009766), (3, 1.9111443758010864), (4, 1.7313601970672607), (5, 2.9305224418640137), (6, 4.474883079528809), (7, 2.9278531074523926), (8, 10.674482345581055), (9, 1.5722403526306152), (10, 1.4518835544586182), (11, 2.6894078254699707), (12, 6.464173793792725), (13, 2.782701015472412), (14, 2.6019043922424316), (15, 11.26655101776123), (16, 14.24222469329834), (17, 6.749372482299805), (18, 4.939309597015381), (19, 1.8154406547546387), (20, 3.5561578273773193), (21, 2.7677488327026367), (22, 4.385662078857422), (23, 2.910689353942871), (24, 2.520022392272949), (25, 6.458834171295166), (26, 1.7105642557144165), (27, 13.653885841369629), (28, 2.3354501724243164), (29, 2.4548184871673584), (30, 1.8848669528961182), (31, 7.571052074432373), (32, 2.37091326713562), (33, 3.323946237564087), (34, 5.306005477905273), (35, 3.955907106399536), (36, 4.056483745574951), (37, 3.156825304031372), (38, 3.675379514694214), (39, 4.043042659759521




In [29]:
import pandas as pd
import plotly.express as px

# Read the per-sample perplexities from the CSV file
df = pd.read_csv('english_perplexity_top6.csv')

# `calculate_perplexity` names the sample-index column after the domain
# (`{domain}_num`), so a hard-coded 'text_num' only works for files whose header
# happens to use that name. Taking the first column by position works for both.
sample_col = df.columns[0]

# Create the line plot using plotly
fig = px.line(df, x=sample_col, y='perplexity',
              markers=True,
              title='Perplexity by Input Number')

# Customize the plot (axis titles are set explicitly, so the column name used
# for x does not leak into the figure)
fig.update_layout(
    xaxis_title="Text Number",
    yaxis_title="Perplexity",
    xaxis=dict(showgrid=True),
    yaxis=dict(showgrid=True)
)

# Display the plot
fig.show()


In [68]:
import pandas as pd
import plotly.express as px
import numpy as np

# One CSV per top-k setting; the label and file lists are built from the same
# range so they stay aligned.
top_ks = range(1, 7)
files = [f'pplx-data/french_perplexity_top{k}.csv' for k in top_ks]
labels = [f'Top {k}' for k in top_ks]
log_means = []

# Iterate over k directly instead of parsing it back out of the file path:
# splitting the path on 'top' / '.' breaks as soon as a directory name contains
# either of them.
for k, file in zip(top_ks, files):
    # Read CSV file
    df = pd.read_csv(file)
    # Mean of the per-sample log perplexities; exponentiating it gives the
    # geometric mean of the per-sample perplexities.
    log_perplexity = np.log(df['perplexity'])
    log_mean = log_perplexity.mean()
    log_means.append(log_mean)

    # Print both log and perplexity space results for reference
    print(f"{k}:")
    print(f"  Mean log perplexity: {log_mean:.3f}")
    print(f"  Equivalent perplexity: {np.exp(log_mean):.3f}")

# Create DataFrame for plotting
plot_df = pd.DataFrame({
    'Selection': labels,
    'Log Perplexity': log_means
})

# Text annotations shown next to each marker
point_labels = [f'{v:.3f}' for v in log_means]

# Create line plot
fig = px.line(plot_df,
              x='Selection',
              y='Log Perplexity',
              markers=True,
              text=point_labels,
              title='Mean Log Perplexity by Top-k Selection')
# Customize the plot
fig.update_layout(
    xaxis_title="Selection Method",
    yaxis_title="Mean Log Perplexity",
    xaxis=dict(showgrid=True)
)

# Place the value labels to the right of the markers
fig.update_traces(textposition="middle right")

# Display the plot
fig.show()

1:
  Mean log perplexity: 4.232
  Equivalent perplexity: 68.882
2:
  Mean log perplexity: 3.566
  Equivalent perplexity: 35.377
3:
  Mean log perplexity: 3.325
  Equivalent perplexity: 27.792
4:
  Mean log perplexity: 3.289
  Equivalent perplexity: 26.821
5:
  Mean log perplexity: 3.271
  Equivalent perplexity: 26.335
6:
  Mean log perplexity: 3.279
  Equivalent perplexity: 26.537
