Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More robust Online DPO changes for RL update #1664

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Update fast_lora.py, added mixed precising pytorch autocasting
  • Loading branch information
pluesclues authored Dec 17, 2024
commit 389b98f4860ab007f02af27258bd68d594749a66
235 changes: 80 additions & 155 deletions unsloth/kernels/fast_lora.py
Original file line number Diff line number Diff line change
@@ -113,34 +113,36 @@ def backward(ctx, dY : torch.Tensor):
h, df, de = DW, e, g

# Down projection LoRA weights
d_downA = h.t() @ (dY @ downB.t())
d_downB = (downA.t() @ h.t()) @ dY
d_downA *= downS
d_downB *= downS

# Up projection LoRA weights
d_upA = X.t() @ (df @ upB.t())
d_upB = (upA.t() @ X.t()) @ df
d_upA *= upS
d_upB *= upS

# Gate projection LoRA weights
d_gateA = X.t() @ (de @ gateB.t())
d_gateB = (gateA.t() @ X.t()) @ de
d_gateA *= gateS
d_gateB *= gateS

# dX = matmul_lora(df, upW.t(), upW_quant, upB, upA, upS)
# dX += matmul_lora(de, gateW.t(), gateW_quant, gateB, gateA, gateS)
upW = fast_dequantize(upW.t(), upW_quant)
dX = torch.matmul(df, upW.t(), out = X if ctx.inplace else None)
del upW
dX += df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t())

gateW = fast_dequantize(gateW.t(), gateW_quant)
dX += de @ gateW.t()
del gateW
dX += de @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t())
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
d_downA = h.t() @ (dY @ downB.t().to(torch.bfloat16))
#breakpoint()
d_downB = (downA.t() @ h.t()) @ dY
d_downA *= downS
d_downB *= downS

# Up projection LoRA weights
d_upA = X.t() @ (df @ upB.t())
d_upB = (upA.t() @ X.t()) @ df
d_upA *= upS
d_upB *= upS

# Gate projection LoRA weights
d_gateA = X.t() @ (de @ gateB.t())
d_gateB = (gateA.t() @ X.t()) @ de
d_gateA *= gateS
d_gateB *= gateS

# dX = matmul_lora(df, upW.t(), upW_quant, upB, upA, upS)
# dX += matmul_lora(de, gateW.t(), gateW_quant, gateB, gateA, gateS)
upW = fast_dequantize(upW.t(), upW_quant)
dX = torch.matmul(df, upW.t(), out = X if ctx.inplace else None)
del upW
dX += df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t())

gateW = fast_dequantize(gateW.t(), gateW_quant)
dX += de @ gateW.t()
del gateW
dX += de @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t())

# gateW, gateW_quant, gateA, gateB, gateS,
# upW, upW_quant, upA, upB, upS,
@@ -270,45 +272,45 @@ def backward(ctx, dQ, dK, dV):

### Weight projection LoRA weights
# See our blogpost for more details.

# Q Projection
d_QA = X.t() @ (dQ @ QB.t())
d_QB = (QA.t() @ X.t()) @ dQ
d_QA *= QS
d_QB *= QS

# K Projection
d_KA = X.t() @ (dK @ KB.t())
d_KB = (KA.t() @ X.t()) @ dK
d_KA *= KS
d_KB *= KS

# V Projection
d_VA = X.t() @ (dV @ VB.t())
d_VB = (VA.t() @ X.t()) @ dV
d_VA *= VS
d_VB *= VS

# Combine derivatives to find dX
# dQ
QW = fast_dequantize(QW.t(), QW_quant)
dX = torch.matmul(dQ, QW.t(), out = X if ctx.inplace else None)
del QW
dX += (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t()))

# dK
KW = fast_dequantize(KW.t(), KW_quant)
dX += dK @ KW.t()
del KW
dX += dK @ KB.to(dtype).t() @ (KS * KA.to(dtype).t())

# dV
VW = fast_dequantize(VW.t(), VW_quant)
dX += dV @ VW.t()
del VW
dX += dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t())

# QW, QW_quant, QA, QB, QS,
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
# Q Projection
d_QA = X.t() @ (dQ @ QB.t())
d_QB = (QA.t() @ X.t()) @ dQ
d_QA *= QS
d_QB *= QS

# K Projection
d_KA = X.t() @ (dK @ KB.t())
d_KB = (KA.t() @ X.t()) @ dK
d_KA *= KS
d_KB *= KS

# V Projection
d_VA = X.t() @ (dV @ VB.t())
d_VB = (VA.t() @ X.t()) @ dV
d_VA *= VS
d_VB *= VS

# Combine derivatives to find dX
# dQ
QW = fast_dequantize(QW.t(), QW_quant)
dX = torch.matmul(dQ, QW.t(), out = X if ctx.inplace else None)
del QW
dX += (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t()))

# dK
KW = fast_dequantize(KW.t(), KW_quant)
dX += dK @ KW.t()
del KW
dX += dK @ KB.to(dtype).t() @ (KS * KA.to(dtype).t())

# dV
VW = fast_dequantize(VW.t(), VW_quant)
dX += dV @ VW.t()
del VW
dX += dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t())

# QW, QW_quant, QA, QB, QS,
# KW, KW_quant, KA, KB, KS,
# VW, VW_quant, VA, VB, VS,
return dX.view(batch, seq_len, hd), \
@@ -387,16 +389,17 @@ def backward(ctx, dY : torch.Tensor):

### Weight projection LoRA weights
# Weight projection
d_A = X.t() @ (dY @ B.t())
d_B = (A.t() @ X.t()) @ dY
d_A *= S
d_B *= S

# Get derivative for dX
W = fast_dequantize(W.t(), W_quant)
dX = dY @ W.t()
del W
dX += dY @ B.to(dtype).t() @ (S * A.to(dtype).t())
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
d_A = X.t() @ (dY @ B.t())
d_B = (A.t() @ X.t()) @ dY
d_A *= S
d_B *= S

# Get derivative for dX
W = fast_dequantize(W.t(), W_quant)
dX = dY @ W.t()
del W
dX += dY @ B.to(dtype).t() @ (S * A.to(dtype).t())

# W, W_quant, A, B, S
return dX.view(batch, seq_len, hd), \
@@ -410,81 +413,3 @@ def apply_lora_o(self, X):
O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS)
return O
pass


IDENTITY_DROPOUT = torch.nn.Identity
@torch._disable_dynamo
def fast_lora_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError(
"Unsloth: Currently not supported yet - reshaping done incorrectly"
)
self._check_forward_args(x, *args, **kwargs)
adapter_names = kwargs.pop("adapter_names", None)

if self.disable_adapters:
if self.merged:
self.unmerge()
result = self.base_layer(x, *args, **kwargs)
elif adapter_names is not None:
result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
else:
# Fastpath
if len(self.active_adapters) == 1:
active_adapter = self.active_adapters[0]
if active_adapter not in self.lora_A.keys(): return self.base_layer(x, *args, **kwargs)

dropout = self.lora_dropout[active_adapter]
if isinstance(dropout, IDENTITY_DROPOUT) and not self.use_dora[active_adapter]:
lora_A = self.lora_A[active_adapter].weight
lora_B = self.lora_B[active_adapter].weight
scaling = self.scaling[active_adapter]
W = self.base_layer.weight
return LoRA_W.apply(x, W, QUANT_STATE(W), lora_A, lora_B, scaling)
pass
pass

result = self.base_layer(x, *args, **kwargs)
# As per Tim Dettmers, for 4bit, we need to defensively clone here.
# The reason is that in some cases, an error can occur that backprop
# does not work on a manipulated view. This issue may be solved with
# newer PyTorch versions but this would need extensive testing to be
# sure.
result = result.clone()

for active_adapter in self.active_adapters:
if active_adapter not in self.lora_A.keys():
continue
lora_A = self.lora_A[active_adapter]
lora_B = self.lora_B[active_adapter]
dropout = self.lora_dropout[active_adapter]
scaling = self.scaling[active_adapter]

requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
x = x.to(lora_A.weight.dtype)

if not self.use_dora[active_adapter]:
result = result + lora_B(lora_A(dropout(x))) * scaling
else:
if isinstance(dropout, torch.nn.Identity) or not self.training:
base_result = result
else:
x = dropout(x)
base_result = None

result = result + self.lora_magnitude_vector[active_adapter](
x,
lora_A=lora_A,
lora_B=lora_B,
scaling=scaling,
base_layer=self.get_base_layer(),
base_result=base_result,
)
if requires_conversion:
result = result.to(expected_dtype)

return result
pass