Hi @aladdinpersson, thanks for the work you share. Could you please provide a clear explanation of how inference works?
I have watched your videos and still don't understand 100% how:
1- The sequence is produced at training time
2- The sequence is produced at test time
I saw your inference script but honestly, the whole thing is still quite blurry to me.
How the sequence is produced at training time
So at training time we have the entire input and entire target sentences, and all we have to do is: Tokenize --> Numericalize --> Pad (so all sequences in the batch are of equal length). I have separate videos where I go into more detail on the data loading part, and you could check out the torchtext videos for that. After that, both sequences are fed to the transformer, and we use masking so that the network doesn't cheat by looking ahead in the target sentence (I've also gone into more depth on this in the transformer-from-scratch video).
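To make the masking part concrete, here is a minimal sketch (assuming PyTorch and a target length of 3) of the causal mask that stops each target position from attending to later positions:

```python
import torch

# A causal ("subsequent") mask for a target of length 3 (length chosen
# just for illustration). True marks positions attention is NOT allowed
# to look at, i.e. position i may only attend to positions j <= i.
tgt_len = 3
tgt_mask = torch.triu(torch.ones(tgt_len, tgt_len), diagonal=1).bool()
print(tgt_mask)
# tensor([[False,  True,  True],
#         [False, False,  True],
#         [False, False, False]])
```

During training this mask is applied to the decoder's self-attention, so the whole target sentence can be processed in one forward pass while each position still only "sees" earlier tokens.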
How the sequence is produced at test time
Obviously at test time we don't have the entire target sentence, only the input sentence, so we output a single word at a time (that's why we have the for i in range(max_length) loop in the translate_sentence function). In the beginning the target consists only of a start token, but on each iteration of the for loop we gain one additional output predicted by the model (we take the highest-probability prediction and append it to our outputs). We continue doing this in the for loop until we either a) reach an EOS token, or b) hit max_length.
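That loop can be sketched roughly as follows. This is a hedged sketch, not the exact translate_sentence code: the greedy_decode name, sos_idx, and eos_idx are assumptions, and the model is assumed to map (src, tgt) to logits of shape (tgt_len, batch, vocab):

```python
import torch

def greedy_decode(model, src, sos_idx, eos_idx, max_length):
    # Hypothetical helper mirroring the idea behind translate_sentence:
    # start from the <sos> token, then repeatedly append the
    # highest-probability next token until <eos> or max_length.
    outputs = [sos_idx]
    for _ in range(max_length):
        tgt = torch.tensor(outputs).unsqueeze(1)  # shape (tgt_len, 1)
        logits = model(src, tgt)                  # (tgt_len, 1, vocab), assumed
        next_token = logits[-1, 0].argmax().item()  # last position's prediction
        outputs.append(next_token)
        if next_token == eos_idx:
            break
    return outputs
```

Note that each iteration re-runs the model on the full prefix generated so far; only the prediction at the last position is used to pick the next token.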
Machine-Learning-Collection/ML/Pytorch/more_advanced/seq2seq_transformer/utils.py
Line 7 in 235beb2