Running the qlora.py provided in the repo raises errors #125
Comments
That is an old version of the `__main__` function; you'll need to update it yourself.
Here is the execution code from qlora.py. How should I change it? I've only done CV so far and am just getting started with multimodal LLMs:

```python
if __name__ == '__main__':
    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.child = nn.Linear(100, 200)

        def forward(self, x):
            return self.child(x)

    model = Model()
    torch.save(model.state_dict(), "linear.pt")
    x = torch.randn(2, 100)
    out1 = model(x)
    model.child = LoraLinear(100, 200, 10)
    model.load_state_dict(torch.load("linear.pt"), strict=False)
    out2 = model(x)
    torch.save(model.state_dict(), "lora.pt")
    ckpt = torch.load("lora.pt")
    breakpoint()
    model.load_state_dict(ckpt, strict=False)
    out3 = model(x)
    breakpoint()
```
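The repo's `LoraLinear` class isn't shown in this thread, so the following is a minimal hypothetical sketch (class name kept from the thread; the signature and internals are assumptions, not the repo's actual code). It keeps the base parameters under the same names as `nn.Linear` (`weight`, `bias`), which is what lets the `strict=False` round-trip above load a plain `nn.Linear` checkpoint into the LoRA layer:

```python
import torch
import torch.nn as nn

class LoraLinear(nn.Module):
    """Hypothetical minimal LoRA linear; the repo's real class may differ.

    Base parameters are registered as `weight`/`bias` (same names as
    nn.Linear), so a plain nn.Linear checkpoint loads with strict=False
    and only lora_A / lora_B are reported as missing keys.
    """
    def __init__(self, in_features, out_features, r, alpha=None):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features))
        nn.init.kaiming_uniform_(self.weight)
        # Low-rank update B @ A; B is zero-initialised so the adapter
        # starts as a no-op and the pre-/post-replacement outputs match.
        self.lora_A = nn.Parameter(torch.randn(r, in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))
        self.scaling = (alpha or r) / r

    def forward(self, x):
        base = nn.functional.linear(x, self.weight, self.bias)
        return base + self.scaling * (x @ self.lora_A.T @ self.lora_B.T)

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.child = nn.Linear(100, 200)

    def forward(self, x):
        return self.child(x)

model = Model()
torch.save(model.state_dict(), "linear.pt")
x = torch.randn(2, 100)
out1 = model(x)

model.child = LoraLinear(100, 200, r=10)
missing = model.load_state_dict(torch.load("linear.pt"), strict=False)
out2 = model(x)  # lora_B is zero, so the adapter contributes nothing

print(torch.allclose(out1, out2))    # True
print(sorted(missing.missing_keys))  # the lora_* keys absent from linear.pt
```

With this layout, `out1 == out2` holds exactly because the loaded base weights are identical and the LoRA term is zero at initialisation.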
I've forgotten too, it's been a long time. Read the source code yourself; it isn't long.
OK, I'll give it a try first.
With this change it still reports that `quant_state` cannot be None. How should I add this `quant_state`?

```python
if __name__ == '__main__':
    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.child = nn.Linear(100, 200)

        def forward(self, x):
            return self.child(x)

    model = Model()
    torch.save(model.state_dict(), "linear.pt")
    x = torch.randn(2, 100)
    out1 = model(x)
    model.child = LoraLinear(nn.Linear, 5, 100, 200, 10, qlora=True)
    model.load_state_dict(torch.load("linear.pt"), strict=False)
    out2 = model(x)
    torch.save(model.state_dict(), "lora.pt")
    ckpt = torch.load("lora.pt")
    breakpoint()
    model.load_state_dict(ckpt, strict=False)
    out3 = model(x)
    breakpoint()
```
You only get a `quant_state` when running on a GPU, which means you need to call `.cuda()`. Also note that `model.cuda` can only be called once, otherwise it will error (that's how bitsandbytes implements it, and I can't control it; they override the `.cuda()` function).
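This is not bitsandbytes' actual code, but a toy mock (all names hypothetical) of the mechanism described above: the quantized layer only gets its `quant_state` when `.cuda()` quantizes the weight, and calling `.cuda()` a second time would re-quantize already-quantized data, which is why it must happen exactly once:

```python
class MockLinear4bit:
    """Toy illustration only; not the real bitsandbytes Linear4bit."""

    def __init__(self, weight):
        self.weight = list(weight)   # pretend these are fp weights
        self.quant_state = None      # only populated after .cuda()

    def cuda(self):
        if self.quant_state is not None:
            # A second call would quantize data that is already
            # quantized, corrupting the layer.
            raise RuntimeError("weight already quantized; .cuda() called twice")
        absmax = max(abs(w) for w in self.weight) or 1.0
        self.quant_state = {"absmax": absmax}  # stats needed to dequantize
        # crude "4-bit" quantization: scale into [-7, 7] integers
        self.weight = [round(w / absmax * 7) for w in self.weight]
        return self

    def dequantized(self):
        if self.quant_state is None:
            # mirrors the "quant_state cannot be None" error in the thread
            raise ValueError("quant_state is None: call .cuda() first")
        scale = self.quant_state["absmax"] / 7
        return [w * scale for w in self.weight]

layer = MockLinear4bit([0.5, -1.0, 0.25])
layer.cuda()                 # quantize once; quant_state is now set
print(layer.dequantized())   # approximate reconstruction of the inputs
```

Before `.cuda()`, `dequantized()` raises exactly because `quant_state` is None; after one call it works, and a second `.cuda()` raises.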
Indeed `.cuda()` can only be called once; calling `.cuda()` on the `LoraLinear` ahead of time raises a dimension error.
Running the repo's qlora.py directly raises an error.
After changing `model.child = LoraLinear(100, 200, 10)` to `model.child = LoraLinear(100, 200, 10, 10, 2)`, it raises another error.
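Rather than guessing extra positional arguments (as in the `10, 10, 2` attempt above), it can help to print the constructor's actual signature. Shown here with `nn.Linear`, since `LoraLinear`'s definition lives in the repo:

```python
import inspect
import torch.nn as nn

# Print the parameter names, order, and defaults of a layer's
# constructor instead of guessing positional arguments.
sig = inspect.signature(nn.Linear.__init__)
print(sig)  # e.g. (self, in_features, out_features, bias=True, ...)
```

The same call on `LoraLinear.__init__` would show which of the repo's parameters (rank, alpha, the `qlora` flag, etc.) go where.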