Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More robust Online DPO changes for RL update #1664

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Update fast_lora.py, removing downcasting stuff
  • Loading branch information
pluesclues authored Feb 14, 2025
commit 06e694f646185b8d5294b9ff94cf1b38cfbd7540
158 changes: 78 additions & 80 deletions unsloth/kernels/fast_lora.py
Original file line number Diff line number Diff line change
@@ -113,36 +113,34 @@ def backward(ctx, dY : torch.Tensor):
h, df, de = DW, e, g

# Down projection LoRA weights
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
d_downA = h.t() @ (dY @ downB.t().to(torch.bfloat16))
#breakpoint()
d_downB = (downA.t() @ h.t()) @ dY
d_downA *= downS
d_downB *= downS

# Up projection LoRA weights
d_upA = X.t() @ (df @ upB.t())
d_upB = (upA.t() @ X.t()) @ df
d_upA *= upS
d_upB *= upS

# Gate projection LoRA weights
d_gateA = X.t() @ (de @ gateB.t())
d_gateB = (gateA.t() @ X.t()) @ de
d_gateA *= gateS
d_gateB *= gateS

# dX = matmul_lora(df, upW.t(), upW_quant, upB, upA, upS)
# dX += matmul_lora(de, gateW.t(), gateW_quant, gateB, gateA, gateS)
upW = fast_dequantize(upW.t(), upW_quant)
dX = torch.matmul(df, upW.t(), out = X if ctx.inplace else None)
del upW
dX += df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t())

gateW = fast_dequantize(gateW.t(), gateW_quant)
dX += de @ gateW.t()
del gateW
dX += de @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t())
d_downA = h.t() @ (dY @ downB.t())
d_downB = (downA.t() @ h.t()) @ dY
d_downA *= downS
d_downB *= downS

# Up projection LoRA weights
d_upA = X.t() @ (df @ upB.t())
d_upB = (upA.t() @ X.t()) @ df
d_upA *= upS
d_upB *= upS

# Gate projection LoRA weights
d_gateA = X.t() @ (de @ gateB.t())
d_gateB = (gateA.t() @ X.t()) @ de
d_gateA *= gateS
d_gateB *= gateS

# dX = matmul_lora(df, upW.t(), upW_quant, upB, upA, upS)
# dX += matmul_lora(de, gateW.t(), gateW_quant, gateB, gateA, gateS)
upW = fast_dequantize(upW.t(), upW_quant)
dX = torch.matmul(df, upW.t(), out = X if ctx.inplace else None)
del upW
dX += df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t())

gateW = fast_dequantize(gateW.t(), gateW_quant)
dX += de @ gateW.t()
del gateW
dX += de @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t())

# gateW, gateW_quant, gateA, gateB, gateS,
# upW, upW_quant, upA, upB, upS,
@@ -272,45 +270,45 @@ def backward(ctx, dQ, dK, dV):

### Weight projection LoRA weights
# See our blogpost for more details.
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
# Q Projection
d_QA = X.t() @ (dQ @ QB.t())
d_QB = (QA.t() @ X.t()) @ dQ
d_QA *= QS
d_QB *= QS

# K Projection
d_KA = X.t() @ (dK @ KB.t())
d_KB = (KA.t() @ X.t()) @ dK
d_KA *= KS
d_KB *= KS

# V Projection
d_VA = X.t() @ (dV @ VB.t())
d_VB = (VA.t() @ X.t()) @ dV
d_VA *= VS
d_VB *= VS

# Combine derivatives to find dX
# dQ
QW = fast_dequantize(QW.t(), QW_quant)
dX = torch.matmul(dQ, QW.t(), out = X if ctx.inplace else None)
del QW
dX += (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t()))

# dK
KW = fast_dequantize(KW.t(), KW_quant)
dX += dK @ KW.t()
del KW
dX += dK @ KB.to(dtype).t() @ (KS * KA.to(dtype).t())

# dV
VW = fast_dequantize(VW.t(), VW_quant)
dX += dV @ VW.t()
del VW
dX += dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t())

# QW, QW_quant, QA, QB, QS,

# Q Projection
d_QA = X.t() @ (dQ @ QB.t())
d_QB = (QA.t() @ X.t()) @ dQ
d_QA *= QS
d_QB *= QS

# K Projection
d_KA = X.t() @ (dK @ KB.t())
d_KB = (KA.t() @ X.t()) @ dK
d_KA *= KS
d_KB *= KS

# V Projection
d_VA = X.t() @ (dV @ VB.t())
d_VB = (VA.t() @ X.t()) @ dV
d_VA *= VS
d_VB *= VS

# Combine derivatives to find dX
# dQ
QW = fast_dequantize(QW.t(), QW_quant)
dX = torch.matmul(dQ, QW.t(), out = X if ctx.inplace else None)
del QW
dX += (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t()))

# dK
KW = fast_dequantize(KW.t(), KW_quant)
dX += dK @ KW.t()
del KW
dX += dK @ KB.to(dtype).t() @ (KS * KA.to(dtype).t())

# dV
VW = fast_dequantize(VW.t(), VW_quant)
dX += dV @ VW.t()
del VW
dX += dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t())

# QW, QW_quant, QA, QB, QS,
# KW, KW_quant, KA, KB, KS,
# VW, VW_quant, VA, VB, VS,
return dX.view(batch, seq_len, hd), \
@@ -389,17 +387,16 @@ def backward(ctx, dY : torch.Tensor):

### Weight projection LoRA weights
# Weight projection
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
d_A = X.t() @ (dY @ B.t())
d_B = (A.t() @ X.t()) @ dY
d_A *= S
d_B *= S

# Get derivative for dX
W = fast_dequantize(W.t(), W_quant)
dX = dY @ W.t()
del W
dX += dY @ B.to(dtype).t() @ (S * A.to(dtype).t())
d_A = X.t() @ (dY @ B.t())
d_B = (A.t() @ X.t()) @ dY
d_A *= S
d_B *= S

# Get derivative for dX
W = fast_dequantize(W.t(), W_quant)
dX = dY @ W.t()
del W
dX += dY @ B.to(dtype).t() @ (S * A.to(dtype).t())

# W, W_quant, A, B, S
return dX.view(batch, seq_len, hd), \
@@ -414,6 +411,7 @@ def apply_lora_o(self, X):
return O
pass


IDENTITY_DROPOUT = torch.nn.Identity
@torch._disable_dynamo
def fast_lora_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: