In [1]:
# install thunder
# !pip install lightning-thunder

In [2]:
import thunder
import torch


def foo(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a * b


jfoo = thunder.jit(foo)
cfoo = torch.compile(foo)


def compile_bench(dim, device):
    print("dim >>", dim)
    %timeit foo(torch.randn(dim, dim,device=device), torch.randn(dim, dim,device=device))
    print("with thunder.jit")
    %timeit jfoo(torch.randn(dim, dim,device=device), torch.randn(dim, dim,device=device))
    print("with torch.compile")
    %timeit cfoo(torch.randn(dim, dim,device=device), torch.randn(dim, dim,device=device))
    print()

In [3]:
device = "cuda"
# compile_bench(10,device)
# compile_bench(100,device)
compile_bench(1000, device)
compile_bench(10000, device)

dim >> 1000
25.1 µs ± 965 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
with thunder.jit
24.9 µs ± 1.69 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
with torch.compile
57.8 µs ± 32.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

dim >> 10000
11.8 ms ± 3.9 ms per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
with thunder.jit
18 ms ± 385 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
with torch.compile
52.3 µs ± 29.2 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)



#### Demonstrating Speedups

In [4]:
import torchvision
from torchvision.models import densenet121

model = densenet121(weights=torchvision.models.densenet.DenseNet121_Weights)

model = model.cuda()
opt_model = torch.compile(model, mode="reduce-overhead")
thun_model = thunder.jit(model)



In [5]:
# INfrence

In [6]:
## let's repeat this
for _ in range(5):
    inp = torch.rand(128, 3, 128, 128).cuda()
    # Eager mode
    with torch.no_grad():
        %time model(inp)

print("~" * 30)

print("Compile mode")

for _ in range(5):
    inp = torch.rand(128, 3, 128, 128).cuda()
    # Compile mode
    with torch.no_grad():
        %time opt_model(inp)

print("Thunder jit")

for _ in range(5):
    inp = torch.rand(128, 3, 128, 128).cuda()
    # Compile mode
    with torch.no_grad():
        %time thun_model(inp)   # It is breaking

CPU times: user 263 ms, sys: 27.1 ms, total: 291 ms
Wall time: 291 ms
CPU times: user 4.14 ms, sys: 171 µs, total: 4.31 ms
Wall time: 4.32 ms
CPU times: user 4.18 ms, sys: 0 ns, total: 4.18 ms
Wall time: 4.19 ms
CPU times: user 4.38 ms, sys: 0 ns, total: 4.38 ms
Wall time: 4.39 ms
CPU times: user 6.05 ms, sys: 0 ns, total: 6.05 ms
Wall time: 6.07 ms
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Compile mode




CPU times: user 18.9 s, sys: 5.3 s, total: 24.2 s
Wall time: 24.7 s
CPU times: user 368 ms, sys: 55 µs, total: 368 ms
Wall time: 367 ms
CPU times: user 223 µs, sys: 224 µs, total: 447 µs
Wall time: 453 µs
CPU times: user 207 µs, sys: 208 µs, total: 415 µs
Wall time: 421 µs
CPU times: user 396 µs, sys: 0 ns, total: 396 µs
Wall time: 401 µs
Thunder jit


AttributeError: The torch language context has no method add_

AttributeError: The torch language context has no method add_

AttributeError: The torch language context has no method add_

AttributeError: The torch language context has no method add_

AttributeError: The torch language context has no method add_

In [7]:
# error : AttributeError: The torch language context has no method add_

# NOT IN A MOOD TO DEBUG AND FIX IT

In [8]:
# Thunder is in its early stages and should not be used for production runs yet.

# However, it can already deliver outstanding performance for pretraining and finetuning LLMs supported by LitGPT,

# Thunder is written entirely in Python. Even its trace is represented as valid Python at all stages of transformation. This allows unprecedented levels of introspection and extensibility.

In [None]:
import os

os.environ["TRANSFORMERS_CACHE"] = "/home/pranav-pc/projects/OpenTransformer/checkpoints"
model_name = "mistralai/Mistral-7B-Instruct-v0.2"
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)



Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]