New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
整合快速生成方法和删除无意义代码(Integrate fast generation method and delete meaningless code) #49
Conversation
为什么sample sequence不用trange了? |
这个trange只是单纯的打印进度信息,但是一般产生的字挺多的,直接满屏都是白块,当然我不知道是否是我自身的cmd有问题,无法只在一行显示,而是逐行打印。 |
你好,那个长度问题?为什么不修改? |
length这个情况本意就是生成多长,而不是想生成总长度多长的句子。要改也可以,没啥区别。不过这里确实有改进的空间,比如弄个length -1 即为自动设置生成最长长度的功能。 |
你好,我理解你说的length的本意,但是这个长度,有的人一使用,就想试试生成长文本,看看效果,此时随便设置个长度,而又恰好直接超过了训练时设置的步长,就直接gg,报错了。所以我个人觉得有必要加个判断 if(length+len(prefix)) > config.n_ctx:raise Exception("长度超过限制,请重新设置"),那么两个方法中的代码就不用改了,当然,如果你觉得真的没什么,那也没事,开源本身就是尊重彼此的意见,求同存异。 def fast_sample_sequence(model,context, length, temperature=1, top_k=0,top_p=0.0, device='cpu'):
inputs = torch.LongTensor(context).view(1, -1).to(device)
past = None
prev = inputs
generate = [] + context
with torch.no_grad():
for i in trange(length):
output = model(prev, past=past)
output, past = output[:2]
output = output[0,-1,:] / temperature
filtered_logits = top_k_top_p_filtering(output, top_k=top_k,top_p=top_p)
next_token = torch.multinomial(torch.softmax(filtered_logits, dim=-1),num_samples=1)
generate.append(next_token.item())
prev = next_token.view(1,1)
return generate |
可以,我来修改一下限制最长长度 |
你好,能说一下训练斗破,最后的loss收敛到多少吗? |
大概0.1的样子吧,记不太清楚了 |
谢谢 |
1 原先的快速生成方法,参数命名与默认的不一致,做了修改,当然上次我也没有意识到同一个方法,参数名居然不同,所以没有作测试,实在抱歉,不过这次我已经做了测试。
2 提供的fast_sample_sequence方法,返回的数据与要求的不匹配,做了修改,并添加了generate方法,根据命令行参数,动态调整使用的模式,默认采用原先的生成方式,经过个人测试,生成250个字,快了2秒。
3 删除无意义代码,比如xlnet判断,由于此是应用GPT-2生成文本,所以不需要加入此判断。