Remove dropout from LitGPTSDPABenchmark (#378)
carmocca committed May 13, 2024
1 parent 82ffaac commit 83ccc18
Showing 1 changed file with 27 additions and 14 deletions.
41 changes: 27 additions & 14 deletions thunder/benchmarks/__init__.py
@@ -2566,23 +2566,20 @@ def make_batch(self) -> tuple[list, dict]:
         return (q, k, v), {"dropout": self.config.dropout}
 
     def fn(self) -> Callable:
-        class nanoGPTScaledDotProductAttention(torch.nn.Module):
-            def __init__(slf):
-                super().__init__()
-
+        class ScaledDotProductAttention(torch.nn.Module):
             def forward(slf, q, k, v, *, dropout):
                 return torch.nn.functional.scaled_dot_product_attention(
                     q, k, v, attn_mask=None, dropout_p=dropout, is_causal=True
                 )
 
-        return nanoGPTScaledDotProductAttention()
+        return ScaledDotProductAttention()
 
 
 class LitGPTSDPABenchmark(NanoGPTSDPABenchmark):
     @classmethod
     @property
     def name(cls) -> str:
-        return "llama2-sdpa"
+        return "litgpt-sdpa"
 
     @classmethod
     @property
@@ -2591,21 +2588,37 @@ def description(cls) -> str:
 
     def __init__(
         self,
-        config: str = "Llama-2-7b-hf",
+        config: str | LitGPTConfig = "Llama-2-7b-hf",
         batchdims: Sequence[int] = (16,),
         device: str = "cuda",
         dtype: dtypes.dtype = thunder.bfloat16,
         requires_grad: bool = True,
     ) -> None:
         from thunder.tests.litgpt_model import Config
 
-        litgptconfig = Config.from_name(config) if not isinstance(config, Config) else config
-        nanogptconfig = NanoGPTConfig(
-            n_head=litgptconfig.n_head,
-            seq_len=litgptconfig.block_size,
-            n_embd=litgptconfig.n_embd,
-        )
-        super().__init__(nanogptconfig, batchdims, device, dtype, requires_grad)
+        # not calling super().__init__() on purpose to avoid the nanogpt config validation
+        self.config = Config.from_name(config) if not isinstance(config, Config) else config
+
+        self.batchdims = batchdims
+        self.device = device
+        self.dtype = dtype
+        self.requires_grad: bool = requires_grad
+
+        # Performs torch dtype conversions
+        self.tdtype: torch.dtype = ltorch.to_torch_dtype(self.dtype)
+
+        # Sets required benchmark parameters
+        self.devices: list[str] = [device]
+
+    def make_batch(self) -> tuple[list, dict]:
+        make = partial(make_tensor, device=self.device, dtype=self.tdtype, requires_grad=self.requires_grad)
+        shape = self.batchdims + (self.config.n_head, self.config.block_size, self.config.head_size)
+
+        q = make(shape)
+        k = make(shape)
+        v = make(shape)
+
+        return (q, k, v), {"dropout": 0.0}  # no litgpt model uses dropout
 
 
 # Taken from HuggingFace Bart-Large model config:
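For readers who want to exercise the change, here is a minimal usage sketch (not part of the commit; it assumes a CUDA device is available and uses only the constructor, make_batch, and fn shown in the hunks above — after this patch the kwargs carry a fixed dropout of 0.0):

from thunder.benchmarks import LitGPTSDPABenchmark

# config may now be a name or a litgpt Config object; dropout is no longer read from it
bench = LitGPTSDPABenchmark(config="Llama-2-7b-hf", batchdims=(1,), device="cuda")  # small batch for illustration

sdpa = bench.fn()                        # the ScaledDotProductAttention module returned by fn()
(q, k, v), kwargs = bench.make_batch()   # kwargs == {"dropout": 0.0}
out = sdpa(q, k, v, **kwargs)            # causal SDPA benchmarked with dropout disabled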
