In [None]:

import sys
import os
sys.path.append(os.getcwd())

import torch
import torch.nn as nn
from models.rdt.model import RDT

def count_parameters(model):
    """Return the total number of trainable scalar parameters in ``model``.

    Only parameters whose ``requires_grad`` flag is set are summed, so frozen
    parameters do not contribute to the result. A model with no trainable
    parameters yields 0.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

import traceback

# Build each candidate RDT config and print its trainable-parameter count right
# away. A failure while constructing a later (larger) config then no longer
# discards the counts already computed for the smaller ones. The model_* names
# remain bound at module level in case later cells use them.
try:
    # Guess 1: a much smaller config -- maybe this is the ~170M variant?
    model_guess_1 = RDT(hidden_size=768, depth=12, num_heads=12)
    print(f"Guess 1 (768/12): {count_parameters(model_guess_1) / 1e6:.2f}M")

    # Guess 2: wider hidden size at the same depth as guess 1.
    model_guess_2 = RDT(hidden_size=1024, depth=12, num_heads=16)
    print(f"Guess 2 (1024/12): {count_parameters(model_guess_2) / 1e6:.2f}M")

    # The default config noted in the model file.
    model_file_default = RDT(hidden_size=1152, depth=28, num_heads=16)
    print(f"File default (1152/28): {count_parameters(model_file_default) / 1e6:.2f}M")

    # 1B config from base.yaml.
    model_1b = RDT(hidden_size=2048, depth=28, num_heads=32)
    print(f"RDT-1B model params: {count_parameters(model_1b) / 1e6:.2f}M")
except Exception:
    # Exploratory notebook cell: show the full traceback rather than aborting.
    traceback.print_exc()
