In [1]:
from utils import *
import os
from contextlib import contextmanager
from tqdm import tqdm

@contextmanager
def cuda_timer(label="Timer"):
    # Create CUDA events
    s = torch.cuda.Event(enable_timing=True)
    e = torch.cuda.Event(enable_timing=True)

    # Start the timer
    s.record()

    try:
        # Execute the block of code
        yield
    finally:
        # End the timer
        e.record()

        # Synchronize and calculate the elapsed time
        torch.cuda.synchronize()
        elapsed_time = s.elapsed_time(e)

        # Print the result in milliseconds
        print(f'{label}: {elapsed_time} ms')
        
def _t(shp, dtype=dtypes.bfloat16): return Tensor.empty(shp, dtype=dtype)

def _parse_dtype(dtype): return str(dtype).split('.')[1]

def to_torch(tensor): 
    arr = tensor.numpy()
    return torch.from_numpy(arr).to(
        getattr(torch, _parse_dtype(tensor.dtype))
    ).to('cuda')

def dict_mapper(_dict, func):
    return {k: func(v) for k, v in _dict.items()}

In [2]:
BS = 2 # batch size
GUIDANCE = 3.5

inp = dict(
    img=_t((BS, 1024, 64)),
    img_ids=_t((BS, 1024, 3)),
    txt=_t((BS, 256, 4096)),
    txt_ids=_t((BS, 256, 3)),
    vec=_t((BS, 768)),
)
db_inp = dict(
    img=_t((BS, 1024, 3072)),
    txt=_t((BS, 256, 3072)),
    vec=_t((BS, 3072)),
    pe=_t((BS, 1, 1280, 64, 2, 2)),
)
  
sb_inp = dict(
    img=_t((BS, 1280, 3072)),
    vec=_t((BS, 3072)),
    pe=_t((BS, 1, 1280, 64, 2, 2)),
) 
# timesteps = get_schedule(
#   num_steps, 
#   inp["img"].shape[1], 
#   shift=(args.name != "flux-schnell")
# )
timesteps_inp = [1.0, 0.75, 0.5, 0.25, 0.0]
vec = _t((1, 768))
t_vec = _t((1,))

In [5]:
compiled_original_model = torch.compile(get_original_flow())
inp_pt = dict_mapper(inp, to_torch)
def _f():
    compiled_original_model(
        img=inp_pt['img'], img_ids=inp_pt['img_ids'], 
        txt=inp_pt['txt'], txt_ids=inp_pt['txt_ids'],
        timesteps=to_torch(t_vec), 
        y=to_torch(vec),
    )

  from .autonotebook import tqdm as notebook_tqdm
You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


Init model
Loading checkpoint
Init AE


In [6]:
# warmup
with cuda_timer('warmup'):
    for _ in tqdm(range(10)): 
        _f()

with cuda_timer('inference'):
    _f()

100%|██████████| 10/10 [01:16<00:00,  7.66s/it]


warmup: 76692.375 ms
inference: 70.80608367919922 ms


In [3]:
sched = get_sched_flux(inp, db_inp, sb_inp, t_vec, vec)

Generating with seed 1726985112:
a horse sized cat eating a bagel
Init model


ram used:  2.64 GB, double_blocks.3.txt_attn.proj.bias                :  13%|▏| 

ram used: 23.78 GB, final_layer.adaLN_modulation.1.bias               : 100%|█| 


loaded weights in 9357.87 ms, 23.78 GB loaded at 2.54 GB/s


In [4]:
# def _setenv(var, val): os.environ[var] = val
# (_setenv(i, j) for i, j in [('MCTS', '100'), ('CACHELEVEL', '0'), ("DEBUG", "1"), ("SRC", "0")])
# os.environ['MCTS'] = '100'
# os.environ['IGNORE_MCTS_CACHE'] = '0'
# os.environ['CACHELEVEL'] = '1'
# os.environ['DEBUG'] = '1'
# os.environ['SRC'] = '0'

In [5]:
# tinygrad_flow_lin = mcts(sched)

In [6]:
device: Compiled = Device[Device.DEFAULT]
# if getenv("BACKWARD"): Tensor.training = True
#     print(f"optimizing for {device}")

    # sched = get_sched_flux_flow_model()
sched = [x for x in sched if x.ast.op is UOps.SINK]
# focus on one kernel
if getenv("KERNEL", -1) >= 0: sched = sched[getenv("KERNEL", -1):getenv("KERNEL", -1)+1]
DEBUG = 3
# work with the schedule
total_tm = 0
running_gflops = 0
usage = {}
kernels = []
for i,si in enumerate(sched):
    if DEBUG >= 3: print(si.ast)
    
    rawbufs = bufs_from_lin(Kernel(si.ast))
    # "linearize" the op into uops in different ways
    lins: List[Tuple[Kernel, str]] = []
    
    # # raw
    # lin = Kernel(si.ast, opts=device.renderer)
    # # lin.hand_coded_optimizations()
    # lins.append((lin, "RAW"))
    
    # # always try hand coded opt
    # lin = Kernel(si.ast, opts=device.renderer)
    # lin.hand_coded_optimizations()
    # lins.append((lin, "HC"))
    
    # # maybe try tensor cores
    # lin = Kernel(si.ast, opts=device.renderer)
    # if lin.apply_tensor_cores():
    #     lins.append((lin, "TC"))
    
    # # try a beam search
    # if beam:=getenv("BEAM"):
    #     lin = Kernel(si.ast, opts=device.renderer)
    #     lin = beam_search(lin, rawbufs, beam, bool(getenv("BEAM_ESTIMATE", 1)))
    #     lins.append((lin, "BEAM"))
    
    # try MCTS
    # if mcts:=getenv("MCTS"):
    lin = Kernel(si.ast, opts=device.renderer)
    lin = mcts_search(lin, rawbufs, True)
    kernels.append(lin)
        # lins.append((lin, "MCTS"))

        

UOp(UOps.SINK, dtypes.void, arg=None, src=(
  UOp(UOps.STORE, dtypes.void, arg=None, src=(
    UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()),
    UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(8, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
    UOp(UOps.ALU, dtypes.float, arg=UnaryOps.RECIP, src=(
      UOp(UOps.ALU, dtypes.float, arg=UnaryOps.EXP2, src=(
        UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
          UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
            UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
              UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
                UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
                  UOp(UOps.CONST, dtypes.float, arg=2.0, src=(
                    UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(9, 15), strides=(0, 0), offset=0, mask=((0, 9), (7, 15)), contiguous=False), V

UOp(UOps.SINK, dtypes.void, arg=None, src=(
  UOp(UOps.STORE, dtypes.void, arg=None, src=(
    UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.bfloat16), arg=0, src=()),
    UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 3072), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),
    UOp(UOps.ALU, dtypes.bfloat16, arg=BinaryOps.ADD, src=(
      UOp(UOps.CONST, dtypes.bfloat16, arg=1.0, src=(
        UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 3072), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
      UOp(UOps.LOAD, dtypes.bfloat16, arg=None, src=(
        UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.bfloat16), arg=1, src=()),
        UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 3072), strides=(0, 0, 1), offset=3072, mask=None, contiguous=False),)), src=()),)),)),)),))
UOp(UOps.SINK, dtypes.void, arg=None, src=(
  UOp(UOps.STORE, dtypes.void, arg=None, src=(


In [7]:
len(kernels)

1425

In [8]:
len(sched)

1425

In [9]:
kernels[0].to_program().op_estimate

480

In [10]:

from tinygrad.engine.search import time_linearizer, beam_search, bufs_from_lin

In [11]:
buf = bufs_from_lin(kernels[0])

In [12]:
print(buf[0].__dir__())

['device', 'size', 'dtype', 'options', 'offset', '_base', '_lb_refcount', 'allocator', '_buf', '__module__', '__init__', 'base', 'lb_refcount', 'ref', 'is_allocated', 'ensure_allocated', 'allocate', '__reduce__', 'nbytes', '__del__', '__repr__', 'as_buffer', 'copyin', 'copyout', 'view', '__dict__', '__weakref__', '__doc__', '__new__', '__hash__', '__str__', '__getattribute__', '__setattr__', '__delattr__', '__lt__', '__le__', '__eq__', '__ne__', '__gt__', '__ge__', '__reduce_ex__', '__subclasshook__', '__init_subclass__', '__format__', '__sizeof__', '__dir__', '__class__']


In [13]:
krn = kernels[0]

In [14]:
print(krn.__dir__())

['opts', 'ast', 'reduceops', 'vars', 'bufs', 'full_buf_index', 'sts', 'applied_opts', 'group_for_reduces', 'upcasted', 'local_dims', 'dont_use_locals', 'tensor_core', 'tensor_core_opts', 'bufs_for_tensor_core', 'use_tensor_cores', 'uops', 'name', '__module__', '__annotations__', '__init__', 'copy', 'membufs', 'float4_axis', 'upcasted_axis', 'first_reduce', 'first_upcast', 'reduceop', 'output_shape', 'full_shape', 'full_unupcasted_shape', 'shape_len', 'upcast_in_mid_reduce_axes', 'global_dims', 'colors', 'colored_shape', 'reshape_and_permute', 'upcast', 'shift_to', 'simplify_ones', 'simplify_merge_adjacent', '_create_tc_opts', '_apply_tc_opt', 'apply_tensor_cores', 'apply_opt', 'required_optimizations', 'hand_coded_optimizations', 'kernel_cnt', 'get_optimized_ast', 'linearize', 'to_program', '__dict__', '__weakref__', '__doc__', '__new__', '__repr__', '__hash__', '__str__', '__getattribute__', '__setattr__', '__delattr__', '__lt__', '__le__', '__eq__', '__ne__', '__gt__', '__ge__', '__r

In [15]:
prg = krn.to_program()

In [17]:
print(prg.__dir__())

['name', 'src', 'dname', 'uops', 'mem_estimate', 'global_size', 'local_size', 'vars', 'globals', 'outs', '_ran_post_init', '__module__', '__annotations__', '__post_init__', 'op_estimate', 'lds_estimate', '_ops_lds', 'outcount', 'function_name', 'launch_dims', '__dict__', '__weakref__', '__doc__', '__dataclass_params__', '__dataclass_fields__', '__init__', '__repr__', '__eq__', '__hash__', '__match_args__', '__new__', '__str__', '__getattribute__', '__setattr__', '__delattr__', '__lt__', '__le__', '__ne__', '__gt__', '__ge__', '__reduce_ex__', '__reduce__', '__subclasshook__', '__init_subclass__', '__format__', '__sizeof__', '__dir__', '__class__']


In [18]:
prg.__module__

'tinygrad.renderer'

In [25]:
from tinygrad.engine.realize import Runner, get_kernel

In [21]:
runner = Runner('dp', 'name')

In [23]:
runner(buf, {})

NotImplementedError: override this

In [24]:
import os;os.environ['BEAM']='100'

In [None]:
get_kernel()

In [26]:
krn.ast

UOp(UOps.SINK, dtypes.void, arg=None, src=(
  UOp(UOps.STORE, dtypes.void, arg=None, src=(
    UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()),
    UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(8, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
    UOp(UOps.ALU, dtypes.float, arg=UnaryOps.RECIP, src=(
      UOp(UOps.ALU, dtypes.float, arg=UnaryOps.EXP2, src=(
        UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
          UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
            UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
              UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
                UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
                  UOp(UOps.CONST, dtypes.float, arg=2.0, src=(
                    UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(9, 15), strides=(0, 0), offset=0, mask=((0, 9), (7, 15)), contiguous=False), V