<a href="https://colab.research.google.com/github/Shopping-Yuan/ML2021HW/blob/Shopping_vscode_branch/HW05_modified.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

------
###Part 0 setting and installing package
------
###Part 1 preparing data set
------
######load data file
######clean data
######pick up line pairs
######tokenize : using sentencepiece
######make data set
------
###Part 2 make model
------
######positional encoding layer
######multihead attention layer
######encoder layer(s)
######decoder layer(s)
######transformer layer
------
###Part 3 training and validation
------
######label smoothing
######beam search
######bleu
------


setting
======
>Here are all parameters using in this project.

In [112]:
setting = {
# using in part 1
"data_info" : {
    "document":"/content",
    "raw_file_name":"/ted2020.tgz",
    "unzip_path":"/train_dev/",
    "source":{
        "lang":"en",
        "raw_data_path":"/train_dev/raw.en",
        "clean_data_path":"/train_dev/clean_en.txt",
        "tokenized_train_data":"/train_dev/tokenized_train_data.txt",
        "tokenized_val_data":"/train_dev/tokenized_val_data.txt"
        },
    "target":{
        "lang":"zh",
        "raw_data_path":"/train_dev/raw.zh",
        "clean_data_path":"/train_dev/clean_zh.txt",
        "tokenized_train_data":"/train_dev/tokenized_train_data.txt",
        "tokenized_val_data":"/train_dev/tokenized_val_data.txt"
        }
},
# using mainly in part 1
# but "vocab_size","pad_id","bos_id","eos_id","max_l" used in other parts
"tokenized_setting" : {
    "vocab_size" : 8000,
    "character_coverage" : 1,
    "model_type" : "unigram", # "bpe",
    "input_sentence_size" : 400000,
    "shuffle_input_sentence" : True,
    "normalization_rule_name" : "nmt_nfkc_cf",
    "pad_id":0,
    "unk_id":1,
    "bos_id":2,
    "eos_id":3,
    "max_l":400
},
# using mainly in part 3
"training_hparas" : {
    "batch_size" : 400
}
}

installing package
------

In [None]:
# used in part 1
!pip install sentencepiece
# used in part 1 and 3
!pip install tqdm
# used in part 2
!pip install torchinfo



preparing data set
=============

load data file
-------------
>Here I load dataset from my drive,  
>but it also can be download from the link below.

In [None]:
# step 1 : download dataset from drive to google colab
# original dataset is in "https://mega.nz/#!vEcTCISJ!3Rw0eHTZWPpdHBTbQEqBDikDEdFPr7fI8WxaXK9yZ9U"

path_doc = setting["data_info"]["document"]
rawdata_file_name = setting["data_info"]["raw_file_name"]
rawdata_file_path = path_doc + rawdata_file_name
unzip_path = path_doc + setting["data_info"]["unzip_path"]

# mount drive
from google.colab import drive
drive_path = path_doc + "/drive"
drive_name = "/MyDrive"
drive.mount(drive_path)

# copy file from drive
import shutil
shutil.copyfile(drive_path + drive_name + rawdata_file_name, rawdata_file_path)

# step 2 : unzip dataset
import tarfile
# open file
file = tarfile.open(rawdata_file_path)
# extracting file
file.extractall(unzip_path)
file.close()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


clean data
------
>First each dataset (source or target) is clean  
>seperately, change to halfwidth and remove/replace  
>some kind of punctuations.

>Also because the number of sentences in one line may be  
>different in line pairs of source and target set (its an error),  
>some special punctuations is add to the end of sentences  
>for the next process dealing with these problem by  
>using sentence pairs instead of lines pairs to form datasets.



In [None]:
import unicodedata
import string
import re
# convert fullwidth to halfwidth
def to_halfwidth(string):
  return "".join(unicodedata.normalize('NFKC',letter) for letter in string)
def clean_s_zh(s):
    s = to_halfwidth(s)
    # step 1 : delete — _
    delete = " _()[]"
    delete_rules = s.maketrans("","",delete)
    s = s.translate(delete_rules)

    # step 2 : replace “” with ""
    to_be_replace = '“”'
    replace = '""'
    replace_dict = dict(zip(to_be_replace,replace))

    # step 3 : add **END** before and after punctuation

    """
    The number of sentences in one line may be different
    in line pairs of source and target set.
    "**END**" is add after "。!?" and ".!?", which can be used
    to check if the number of sentence in the pair are equal
    in the next process.
    also in english, "." may be use in abbreviation,
    these different use must be identified.

    """

    punctuation = "。!?"
    for char in punctuation:
      replace_dict[char] = char + "**END**"

    replace_rules = s.maketrans(replace_dict)
    s = s.translate(replace_rules)

    zh_list = s.strip("\n").split("\n")

    return zh_list

def clean_s_en(s):
    s = to_halfwidth(s)

    replace_dict = {}

    delete = "-()[]"
    for char in delete:
      replace_dict[char] = ""

    punctuation = "!?"
    for char in punctuation:
      replace_dict[char] = char + "**END**"
    replace_rules = s.maketrans(replace_dict)
    s = s.translate(replace_rules)

    # Identify if "." is used in abbreviation,
    # if not, add "**END**" after it.
    pattern = re.compile(r"(?<!([.\s\r\n\f][a-zA-Z]))[.]")
    s = pattern.sub("**END**",s)

    # test pattern
    # pattern = re.compile(r"(?<!([.\s\r\n\f][a-zA-Z]))[.]")
    # result = pattern.sub("**END**","There are many people in U.S. w.r.t. in Taiwan.Thank you.")

    en_list = s.strip("\n").split("\n")

    return en_list

pick up line pairs
------
>pick up line pairs has equal number of sentences and  
>split them into sentences to form sourse/target dataset.  
>Remove sentences with too many words for training and validation.

In [None]:
# using "**END**" to split line pairs to check if they have equal sentence
def divide_by_END(s):
    list_s = []
    for line_string in s.strip("**END**").split("**END**"):
      if line_string not in [""," "]:
         list_s.append(line_string)
    return(list_s)
'''
warning : devide_en_again function is apply just beacause
in "this" dataset english sentences end with ":" or ";"
sometimes not splited well.
If the dataset is change, this part may need to be
eliminated or modified.
'''
def devide_en_again(s,punctuation = ":;"):
    replace_dict = {}
    for char in punctuation:
      replace_dict[char] = char + "**END**"

    replace_rules_src = s.maketrans(replace_dict)
    new_s = divide_by_END(s.translate(replace_rules_src))
    return new_s

# remove "sentence" if it is too long.
def remove_too_long(src_list,tgt_list,threshold = setting["tokenized_setting"]["max_l"]):
    too_long_src = 0
    too_long_tgt = 0
    remove = False
    new_s = []
    new_t = []
    for i in range(len(src_list)):
      if ((len(src_list[i])>threshold)):
        remove = True
        too_long_src += 1
      if (len(tgt_list[i])>threshold):
        remove = True
        too_long_tgt += 1
      if remove == False:
        new_s.append(src_list[i])
        new_t.append(tgt_list[i])
      else :
        remove = False
    return(new_s,new_t,too_long_src,too_long_tgt)

# pick up good line pairs for traning and validation model
def check_data_pairs(src_list,tgt_list):
    index = 0
    new_src_list = []
    new_tgt_list = []

    same = 0
    add_next = 0
    split_again = 0
    not_use = 0

    while(index < len(src_list)):

      src = divide_by_END(src_list[index])
      tgt = divide_by_END(tgt_list[index])

      # case 1 : src is as long as tgt , finished.
      if len(src) == len(tgt):
        new_src_list += src
        new_tgt_list += tgt
        same += 1
        index += 1

      else :
        # if it is not the last one : both src and tgt add next sentence
        if index != len(src_list)-1:
          src_add_next = divide_by_END(src_list[index] + src_list[index+1])
          tgt_add_next = divide_by_END(tgt_list[index] + tgt_list[index+1])
          # case 2 : src_add_next is as long as tgt_add_next , finished.
          if len(src_add_next) == len(tgt_add_next):
            new_src_list += src_add_next
            new_tgt_list += tgt_add_next
            add_next += 2
            index += 2

          # using new punctuation to divide tgt (english) sentence.
          # note that this part could cause negative effects if the dataset is change.
          else :
            src_add_next = devide_en_again(src_list[index] + src_list[index+1])
            # case 3 : src_add_next is as long as tgt_add_next , finished.
            if len(src_add_next) == len(tgt_add_next):
              new_src_list += src_add_next
              new_tgt_list += tgt_add_next
              split_again +=2
              index += 2

            # case 4 : sentence will not be used.
            else :
              not_use += 1
              # if to_do == 1 :
              #   print(index,src_add_next,tgt_add_next,len(src_add_next),len(tgt_add_next))
              index += 1

        # if it is the last one
        else :
          not_use += 1
          index += 1
    # print information
    print(f"The original total number of line is {index}.")
    print(f"The number of line pairs have the equal sentences is {same}.")
    print(f"The number of line pairs have the equal sentences after combine the next lines is {add_next}.")
    print(f"The number of line pairs have the equal sentences after combine the next lines"+\
       f"and resplit english lines using :; is {split_again}.")
    print(f"The number of line we don't use is {not_use}.")
    print(f"Note that {index} = {same}+{add_next}+{split_again}+{not_use}.")

    # remove long lines
    print(f"The total number of sentence pairs before remove long sentences is {len(new_src_list)}.")
    new_src_list,new_tgt_list,too_long_src,too_long_tgt = remove_too_long(new_src_list,new_tgt_list)
    print(f"The finally total number of sentence pairs using is {len(new_src_list)}.")
    print(f"Note that {len(new_src_list)} are the number of sentence pairs, not line pairs")

    return(new_src_list,new_tgt_list)

load and clean data
------

In [None]:
# load and clean data
def load_file(path,function):
    with open(path, "r") as f:
      data = f.read()
      return function(data)
# saving to new path
def clean_data_and_save(
    path_doc = setting["data_info"]["document"],
    raw_src_path = setting["data_info"]["source"]["raw_data_path"],
    raw_tgt_path = setting["data_info"]["target"]["raw_data_path"],
    clean_src_path = setting["data_info"]["source"]["clean_data_path"],
    clean_tgt_path = setting["data_info"]["target"]["clean_data_path"]
    ):
    raw_src_path = path_doc + raw_src_path
    raw_tgt_path = path_doc + raw_tgt_path
    src = load_file(raw_src_path,clean_s_en),
    tgt = load_file(raw_tgt_path,clean_s_zh),
    # src , tgt are tuples with only one term : src_list, tgt_list
    src_list = src[0]
    tgt_list = tgt[0]
    clean_src_list, clean_tgt_list = check_data_pairs(src_list,tgt_list)
    with open(path_doc + clean_src_path, "w") as f:
      f.write("\n".join(clean_src_list))
    with open(path_doc + clean_tgt_path, "w") as f:
      f.write("\n".join(clean_tgt_list))
# clean_data_and_save()

tokenize
------
>using sentencepiece to tokenize sentences,  
>first make the english/chinese dictionary separately,  
>then use these dict to encode sentence pair in dataset,  
>including add bos/eos/padding to tokenized sentences.  
>Finally split then into train/val set and save.

In [108]:
import sentencepiece as spm
import numpy as np
import torch.utils.data as data
def tokenized(clean_data_path,
       vocab_size,
       lang,
       tokenized_setting
       ):
  model_prefix = f"spm_{vocab_size}_{lang}"
  spm.SentencePieceTrainer.train(
      input=clean_data_path,
      **tokenized_setting,
      model_prefix=model_prefix,
  )
  return(model_prefix)
def bos_eos_padding(dataset,
          max_l,
          vocab_size,
          src_lang,
          tgt_lang
          ):

  s_src = spm.SentencePieceProcessor(model_file = path_doc + f"/spm_{vocab_size}_{src_lang}" +".model")
  s_tgt = spm.SentencePieceProcessor(model_file = path_doc + f"/spm_{vocab_size}_{tgt_lang}" +".model")
  padding_src = []
  padding_tgt = []
  len_s = 0
  len_t = 0
  for src,tgt in dataset:
    s = s_src.encode(src, out_type=int)
    s = np.append(s,[3])
    s = np.append([2],np.pad(s,(0, max_l-len(s)), constant_values = 0))
    padding_src.append(s)

    t = s_tgt.encode(tgt, out_type=int)
    t = np.append(t,[3])
    t = np.append([2],np.pad(t,(0, max_l-len(t)), constant_values = 0))
    padding_tgt.append(t)

  return(list(zip(padding_src,padding_tgt)))
# test SentencePieceProcessor and bos_eos_padding
# s_src = spm.SentencePieceProcessor(model_file="/content/spm8000_en.model")
# s_src.encode("hello world!", out_type=int)
# bos_eos_padding([("hello world","_哈囉")],5,10)

def data_set_preparing(path_doc,
            clean_src_path,
            clean_tgt_path,
            max_l,
            vocab_size,
            src_lang,
            tgt_lang,
            st_train_path,
            st_val_path,
            tt_train_path,
            tt_val_path,
            ):
    src_set = []
    tgt_set = []

    with open(path_doc+clean_src_path,"r") as in_f :
      for line in in_f:
        src_set.append(line)
    with open(path_doc+clean_tgt_path,"r") as in_f :
      for line in in_f:
        tgt_set.append(line)

    dataset = list(zip(src_set,tgt_set))
    dataset = bos_eos_padding(dataset,max_l,vocab_size,src_lang,tgt_lang)
    train_set, valid_set = data.random_split(dataset,[0.9,0.1])
    # print(train_set[0][0])

    with open(path_doc + st_train_path, 'w') as out_f:
      for line_pair in train_set:
        out_f.write(" ".join(str(x) for x in line_pair[0])+"\n")
    with open(path_doc + st_val_path, 'w') as out_f:
      for line_pair in valid_set:
        out_f.write(" ".join(str(x) for x in line_pair[0])+"\n")
    with open(path_doc + tt_train_path, 'w') as out_f:
      for line_pair in train_set:
        out_f.write(" ".join(str(x) for x in line_pair[1])+"\n")
    with open(path_doc + tt_val_path, 'w') as out_f:
      for line_pair in valid_set:
        out_f.write(" ".join(str(x) for x in line_pair[1])+"\n")

In [109]:
def tokenized_data(vocab_size = setting["tokenized_setting"]["vocab_size"],
    tokenized_setting = {k:setting["tokenized_setting"][k] for k in \
               set(list(setting["tokenized_setting"].keys()))-{"vocab_size","max_l"}},
    max_l = setting["tokenized_setting"]["max_l"],
    path_doc = setting["data_info"]["document"],
    clean_src_path = setting["data_info"]["source"]["clean_data_path"],
    clean_tgt_path = setting["data_info"]["target"]["clean_data_path"],
    src_lang = setting["data_info"]["source"]["lang"],
    tgt_lang = setting["data_info"]["target"]["lang"],
    st_train_path = setting["data_info"]["source"]["tokenized_train_data"],
    st_val_path = setting["data_info"]["source"]["tokenized_val_data"],
    tt_train_path = setting["data_info"]["target"]["tokenized_train_data"],
    tt_val_path = setting["data_info"]["target"]["tokenized_val_data"],
    ):
  tokenized(path_doc + clean_src_path,vocab_size,src_lang,tokenized_setting)
  tokenized(path_doc + clean_tgt_path,vocab_size,tgt_lang,tokenized_setting)
  data_set_preparing(path_doc,clean_src_path,clean_tgt_path,max_l,vocab_size,src_lang,
            tgt_lang,st_train_path,st_val_path,tt_train_path,tt_val_path,)
tokenized_data()

make data set
------
> Using tokenized data to make dataset.  
> Classmethod : padding_mask_batch which  
> where the key padding mask is constucted  
> also defined here.

In [115]:
import torch
from tqdm import tqdm
import numpy as np
from torch.utils.data import Dataset

class myDataset(Dataset):
  def __init__(self,src_path,tgt_path):

    self.src_path = src_path
    self.tgt_path = tgt_path

    src_list = []
    with open(self.src_path,"r") as f :
      d_l = f.readlines()
      for line in tqdm(d_l):
        int_list = [int(i) for i in line.split()]
        src_list.append(int_list)
    self.src = torch.LongTensor(src_list)

    tgt_list = []
    with open(self.tgt_path,"r") as f :
      l_l = f.readlines()
      for line in tqdm(l_l):
        int_list = [int(i) for i in line.split()]
        tgt_list.append(int_list)
    self.tgt = torch.LongTensor(np.array(tgt_list))

  def __len__(self):
    return len(self.src)

  def __getitem__(self, index):
    return self.src, self.tgt

  # make key padding mask
  @classmethod
  def padding_mask_batch(cls,batch,pad_id = setting["tokenized_setting"]["pad_id"]):
    """Collate a batch of data."""
    src, tgt = zip(*batch)
    src = torch.stack(src)
    tgt = torch.stack(tgt)
    src_padding = (src == pad_id)
    tgt_padding = (tgt == pad_id)

    return src, tgt , src_padding, tgt_padding
# test myDataset
# data = []
# with open("/content/train_dev/tokenized_train_data.txt","r") as f :
#   d_l = f.readlines()
#   for line in tqdm(d_l):
#     int_list = [int(i) for i in line.split()]
#     data.append(int_list)

In [116]:
from torch.utils.data import DataLoader
def get_data_set(batch_size = setting["training_hparas"]["batch_size"],
         path_doc = setting["data_info"]["document"],
         st_train_path = setting["data_info"]["source"]["tokenized_train_data"],
         st_val_path = setting["data_info"]["source"]["tokenized_val_data"],
         tt_train_path = setting["data_info"]["target"]["tokenized_train_data"],
         tt_val_path = setting["data_info"]["target"]["tokenized_val_data"],
         ):

  train_set = myDataset(src_path = path_doc + st_train_path,
              tgt_path = path_doc + tt_train_path,
              )
  valid_set = myDataset(src_path = path_doc + st_val_path,
              tgt_path = path_doc + tt_val_path,
              )
  train_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    collate_fn=myDataset.padding_mask_batch
  )
  valid_loader = DataLoader(
    valid_set,
    batch_size=batch_size,
    num_workers=8,
    pin_memory=True,
    collate_fn=myDataset.padding_mask_batch
  )
  return train_loader,valid_loader
get_data_set()

100%|██████████| 349149/349149 [00:38<00:00, 9155.13it/s] 
100%|██████████| 349149/349149 [00:36<00:00, 9642.77it/s] 
100%|██████████| 38794/38794 [00:03<00:00, 10204.98it/s]
100%|██████████| 38794/38794 [00:03<00:00, 12823.07it/s]


(<torch.utils.data.dataloader.DataLoader at 0x7be67e044280>,
 <torch.utils.data.dataloader.DataLoader at 0x7be67e045f30>)

make model
======
positional encoding layer
------
>The first layer is embedding layer, where each integers  
>in encoder sentence will be represent by a vector.   
>I use build-in class in pytorch to finish these part,    
>and combine it with encoder layers to form my encoder.

>The layer below is the second layer :positional encoding layer  
>in this layer the position infomation is add to each "word"  
>in the sentence.
>Here I use parameters instead of constant as  
>position infomation so they will change during training process.

In [117]:
import torch
import torch.nn as nn
class Positional_Encoding(nn.Module):
    def __init__(self,max_sentence_length,embedding_dimension):
      super().__init__()
      self.dropout = nn.Dropout(0.1)
      self.encoding_values = nn.Parameter(nn.init.normal_(torch.empty(max_sentence_length,1, embedding_dimension)))
    def forward(self, x):
        # the shape of x : [batch,length,e_dim]
        # the shape of self.encoding_values : [batch,length,e_dim]
        x = x + self.encoding_values.unsqueeze(0)
        x = x.squeeze(-2)
        return self.dropout(x)

multihead attention layer
------


In [118]:
import torch.nn.functional as F
import torchvision
import math
from torchinfo import summary

def attn_mask(input_dim):
    return nn.Transformer.generate_square_subsequent_mask(input_dim)
#attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


# This part is modify from pytorch : torch.nn.functional.scaled_dot_product_attention
# Efficient implementation equivalent to the following:
def scaled_dot_product_attention(query, key, value, padding_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
    # Efficient implementation equivalent to the following:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype)

    if is_causal:
      # assert attn_mask is None
      temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
      attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
      attn_bias.to(query.dtype)

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias

    if padding_mask is not None:
        if padding_mask.dtype == torch.bool:
          padding_mask = torch.zeros_like(padding_mask,dtype = float).masked_fill_(padding_mask, (float("-inf")))

        padding_mask = padding_mask.unsqueeze(0).unsqueeze(0)
        padding_mask.to(query.dtype)

        attn_weight = attn_weight.transpose(-4,-2)
        attn_weight += padding_mask
        attn_weight = attn_weight.transpose(-4,-2)

    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value
# test scaled_dot_product_attention
# t = torch.rand([2,3,4,5])
# mask = torch.tensor([[False,False,True,True],[False,True,False,True]],dtype = torch.bool)
# print(scaled_dot_product_attention(t,t,t,padding_mask= mask, is_causal=True))
# from torch.nn.functional import scaled_dot_product_attention
class My_MultiHeadedAttention(nn.Module):
    def __init__(self, kv_input_dimension, embedding_dimension, num_heads, dropout=0.1, if_decoder = False):
        '''
        embedding_dimension = input dimension
        note that there are residual sublayers in MultiHeadedAttention
        '''
        super().__init__()
        assert embedding_dimension % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.kv_d = kv_input_dimension
        self.d = embedding_dimension
        self.num_heads = num_heads
        self.is_causal = if_decoder
        self.linear_for_q = nn.Linear(self.d, self.d)
        self.linear_for_kv = nn.Linear(self.kv_d, 2 * self.d)
        self.linear_out_project = nn.Linear(self.d, self.d)

    def forward(self, q_input_data, kv_input_data , padding_mask = None):

        query = self.linear_for_q(q_input_data)
        key, value = self.linear_for_kv(kv_input_data).split(self.d,dim = -1)

        query,key,value = \
          map(lambda x : x.view(x.size(0),x.size(1),self.num_heads,self.d//self.num_heads),[query,key,value])
        query,key,value = \
          map(lambda x : x.transpose(-2,-3),[query,key,value])

        x = scaled_dot_product_attention(query,key,value,padding_mask = padding_mask, dropout_p =0, is_causal = self.is_causal)
        x = x.transpose(-2,-3).contiguous()
        x = x.view(x.size(0),x.size(1),self.d)
        x = self.linear_out_project(x)

        return x
# test My_MultiHeadedAttention
# model = My_MultiHeadedAttention(64,128,2)
# q_input = torch.rand(32,400,128)
# kv_input = torch.rand(32,400,64)
# mask = (torch.FloatTensor(32,400).uniform_() > 0.8)
# print(model(q_input,kv_input,mask).size())
# print(summary(model,q_input_data = q_input, kv_input_data = kv_input,padding_mask = mask))

encoder layer(s)
------

In [119]:
import math
class My_Encoder_Layer(nn.Module):
  def __init__(self,embedding_dimension,feedforward_dimension):
    super().__init__()
    self.emb_dim = embedding_dimension
    self.fwd_dim = feedforward_dimension

    self.attention = My_MultiHeadedAttention(self.emb_dim, self.emb_dim ,num_heads = 2, dropout=0)
    self.layer_norm_attn = nn.LayerNorm(self.emb_dim)
    self.drop_out_attn_layernorm = nn.Dropout(0)

    self.feedforward = nn.Sequential(
    nn.Linear(self.emb_dim,self.fwd_dim),
    nn.ReLU(),
    nn.Linear(self.fwd_dim,self.emb_dim)
    )
    self.layer_norm_feedforward = nn.LayerNorm(self.emb_dim)
    self.drop_out_feedforward_layernorm = nn.Dropout(0)


  def forward(self,x,padding_mask):
    x = x + self.attention(x,x,padding_mask)
    x = self.layer_norm_attn(x)

    x = self.drop_out_attn_layernorm(x)

    x = x + self.feedforward(x)
    x = self.layer_norm_feedforward(x)
    x = self.drop_out_feedforward_layernorm(x)

    return x
# test My_Encoder_Layer
# model = My_Encoder_Layer(128,256)
# input = torch.rand((32,400,128))
# mask = (torch.FloatTensor(32,400).uniform_() > 0.8)
# print(model(input,mask).size())
# print(summary(model,input_data = input,padding_mask = mask))
# print(model.state_dict().keys())
class My_Encoder(nn.Module):
  def __init__(self,max_sentence_length,dictionary_length,embedding_dimension,feedforward_dimension,layer_num = 2):
    super().__init__()
    self.max_l = max_sentence_length
    self.dict_l = dictionary_length
    self.emb_dim = embedding_dimension
    self.fwd_dim = feedforward_dimension
    self.encoder_embedding = nn.Embedding(self.dict_l,self.emb_dim,padding_idx=0)
    self.positional_encoding = Positional_Encoding(self.max_l,self.emb_dim)
    self.encoder = nn.ModuleList([My_Encoder_Layer(self.emb_dim,self.fwd_dim) for i in range(layer_num)])

  def forward(self,x,padding_mask):
    x = self.encoder_embedding(x)* math.sqrt(self.emb_dim)
    x = self.positional_encoding(x)

    for index,module in enumerate(self.encoder):
      if index == 0:
        x = module(x,padding_mask)
      else:
        x = module(x,None)
    return x
# test My_Encoder
# model = My_Encoder(400,8000,128,256)
# input = torch.randint(0,7999,(32,400,1),dtype = torch.long)
# mask = (torch.FloatTensor(32,400).uniform_() > 0.8)
# print(model(input,mask).size())
# print(summary(model,input_data = input,padding_mask = mask))
# print(model.state_dict().keys())

decoder layer(s)
------

In [120]:
import math
class My_Decoder_Layer(nn.Module):
  def __init__(self,encoder_embedding_dimension,embedding_dimension,feedforward_dimension):
    super().__init__()
    self.encoder_dim = encoder_embedding_dimension
    self.emb_dim = embedding_dimension
    self.fwd_dim = feedforward_dimension

    self.self_attention = My_MultiHeadedAttention \
     (self.emb_dim,self.emb_dim, num_heads = 2, dropout=0, if_decoder = True)
    self.layer_norm_sa = nn.LayerNorm(self.emb_dim)
    self.drop_out_sa = nn.Dropout(0)

    self.feedforward_sa = nn.Sequential(
    nn.Linear(self.emb_dim,self.fwd_dim),
    nn.ReLU(),
    nn.Linear(self.fwd_dim,self.emb_dim)
    )
    self.layer_norm_sa_fw = nn.LayerNorm(self.emb_dim)
    self.drop_out_sa_fw = nn.Dropout(0)

    self.cross_attention = My_MultiHeadedAttention \
     (self.encoder_dim, self.emb_dim, num_heads = 2, dropout=0, if_decoder = True)
    self.layer_norm_ca = nn.LayerNorm(self.emb_dim)
    self.drop_out_ca = nn.Dropout(0)

    self.feedforward_ca = nn.Sequential(
    nn.Linear(self.emb_dim,self.fwd_dim),
    nn.ReLU(),
    nn.Linear(self.fwd_dim,self.emb_dim)
    )
    self.layer_norm_ca_fw = nn.LayerNorm(self.emb_dim)
    self.drop_out_ca_fw = nn.Dropout(0)

  def forward(self,encoder_input,input,padding_mask):

    x = input + self.self_attention(input,input,padding_mask)
    x = self.layer_norm_sa(x)
    x = self.drop_out_sa(x)

    x = x + self.feedforward_sa(x)
    x = self.layer_norm_sa_fw(x)
    x = self.drop_out_sa_fw(x)

    x = x + self.cross_attention(x,encoder_input,padding_mask)
    x = self.layer_norm_ca(x)
    x = self.drop_out_ca(x)

    x = x + self.feedforward_ca(x)
    x = self.layer_norm_ca_fw(x)
    x = self.drop_out_ca_fw(x)

    return x
class My_Decoder(nn.Module):
  def __init__(self,max_sentence_length,dictionary_length,encoder_embedding_dimension,\
               embedding_dimension,feedforward_dimension,layer_num = 2):
    super().__init__()
    self.max_l = max_sentence_length
    self.dict_l = dictionary_length
    self.encoder_dim = encoder_embedding_dimension
    self.emb_dim = embedding_dimension
    self.fwd_dim = feedforward_dimension
    self.decoder_embedding = nn.Embedding(self.dict_l,self.emb_dim,padding_idx=0)
    self.positional_encoding = Positional_Encoding(self.max_l,self.emb_dim)
    self.decoder = nn.ModuleList([My_Decoder_Layer(self.encoder_dim,self.emb_dim,self.fwd_dim) for i in range(layer_num)])
    # self.encoder = My_Encoder_Layer(self.emb_dim,self.fwd_dim)

    self.generator = nn.Linear(self.emb_dim,self.dict_l)

  def forward(self,encoder_input,input,padding_mask):
    x = self.decoder_embedding(input)* math.sqrt(self.emb_dim)
    x = self.positional_encoding(x)
    # x = self.encoder(x,padding_mask)
    for index,module in enumerate(self.decoder):
      if index == 0:
        x = module(encoder_input,x,padding_mask)
      else:
        x = module(encoder_input,x,None)
    x = self.generator(x)
    x = F.log_softmax(x,dim = -1)
    return x
# test My_Decoder
# model = My_Decoder(400,8000,128,64,256)
# encoder_input = torch.rand(32,400,128)
# input = torch.randint(0,7999,(32,400,1),dtype = torch.long)
# mask = (torch.FloatTensor(32,400).uniform_() > 0.8)
# print(model(encoder_input = encoder_input,input = input, padding_mask = mask).size())
# print(summary(model,encoder_input = encoder_input,input = input, padding_mask = mask))
# print(model.state_dict().keys())

transformer layer
------

In [121]:
class My_Transformer(nn.Module):
  def __init__(self,max_sentence_length,dictionary_length,\
               encoder_embedding_dimension,decoder_embedding_dimension,feedforward_dimension,layer_num = 2):
    super().__init__()
    self.max_l = max_sentence_length
    self.dict_l = dictionary_length
    self.en_dim = encoder_embedding_dimension
    self.de_dim = decoder_embedding_dimension
    self.fw_dim = feedforward_dimension
    self.layer_num = layer_num
    self.encoder = My_Encoder \
     (self.max_l,self.dict_l,self.en_dim,self.fw_dim,self.layer_num)
    self.decoder = My_Decoder \
     (self.max_l,self.dict_l,self.en_dim,self.de_dim,self.fw_dim,self.layer_num)

  def forward(self,src,tgt,src_mask,tgt_mask):
    memory = self.encoder(src,src_mask)
    outputs = self.decoder(memory,tgt,tgt_mask)
    return outputs
# test My_Transformer
# model = My_Transformer(400,8000,128,64,256,3)
# src = torch.randint(0,8000,(32,400,1),dtype = torch.long)
# tgt = torch.randint(0,8000,(32,400,1),dtype = torch.long)
# src_mask = torch.cat(((torch.FloatTensor(32,200).uniform_() > 1),(torch.FloatTensor(32,200).uniform_() > 0.15)),dim =1)
# tgt_mask = torch.cat(((torch.FloatTensor(32,100).uniform_() > 1),(torch.FloatTensor(32,300).uniform_() > 0.15)),dim =1)
# out = model(src,tgt,src_mask,tgt_mask)
# print(out.size(),out.dim(),out[0][0])
# print(summary(model,src = src,tgt = tgt,src_mask = src_mask,tgt_mask = tgt_mask))
# print(model.state_dict().keys())

training and validation
======
label smoothing
------

In [122]:
import gc
class LabelSmoothedCrossEntropyCriterion(nn.Module):
  def __init__(self,dictionary_length,smoothing,padding_id):
        super().__init__()
        self.dict_len = dictionary_length
        self.smoothing = smoothing
        self.padding_id = padding_id
  def forward(self, outputs , label):

    label_onehot = F.one_hot(label,self.dict_len).float().squeeze()
    add = self.smoothing / (self.dict_len)
    label_onehot += add

    label_smoothed = label_onehot.masked_fill_((label_onehot > 1),float(1-self.smoothing+add))
    loss = -1*torch.sum((outputs*label_smoothed),dim = -1)

    label_padding_mask = (label == self.padding_id).squeeze()
    mask_loss = loss.masked_fill_(label_padding_mask,0)

    avg_loss = torch.mean(mask_loss)

    del label,label_onehot,label_smoothed,label_padding_mask,loss,mask_loss
    gc.collect()
    return(avg_loss)
# ignore_index not work correctly
# def LabelSmoothedCrossEntropy(outputs , label,dictionary_length,smooth,padding_id):
#   print(outputs.shape)
#   print(label.shape)
#   label_onehot = label.transpose(-1,-2).squeeze()
#   outputs = outputs.transpose(-1,-2)
#   cal_loss = nn.CrossEntropyLoss(ignore_index = padding_id,reduction = "mean", label_smoothing=smooth)
#   return cal_loss(outputs,label_onehot)

# test LabelSmoothedCrossEntropyCriterion
# cal1 = LabelSmoothedCrossEntropyCriterion(8000,0.1,0)
# print(list(iter(cal1.state_dict())))
# cal2 = LabelSmoothedCrossEntropy(out,tgt,8000,0.1,0)
# print(cal1(out,tgt),cal2)

In [None]:
# see https://arxiv.org/pdf/1512.00567.pdf page 7

#Ref 1 : Hong-Yi Li ML2021 HW5

# class LabelSmoothedCrossEntropyCriterion(nn.Module):
#     def __init__(self, smoothing, ignore_index=None, reduce=True):
#         super().__init__()
#         self.smoothing = smoothing
#         self.ignore_index = ignore_index
#         self.reduce = reduce

#     def forward(self, lprobs, target):
#         if target.dim() == lprobs.dim() - 1:
#             target = target.unsqueeze(-1)
#         # nll: Negative log likelihood，the cross-entropy when target is one-hot. following line is same as F.nll_loss
#         nll_loss = -lprobs.gather(dim=-1, index=target)
#         #  reserve some probability for other labels. thus when calculating cross-entropy,
#         # equivalent to summing the log probs of all labels
#         smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
#         if self.ignore_index is not None:
#             pad_mask = target.eq(self.ignore_index)
#             nll_loss.masked_fill_(pad_mask, 0.0)
#             smooth_loss.masked_fill_(pad_mask, 0.0)
#         else:
#             nll_loss = nll_loss.squeeze(-1)
#             smooth_loss = smooth_loss.squeeze(-1)
#         if self.reduce:
#             nll_loss = nll_loss.sum()
#             smooth_loss = smooth_loss.sum()
#         # when calculating cross-entropy, add the loss of other labels
#         eps_i = self.smoothing / lprobs.size(-1)
#         loss = (1.0 - self.smoothing) * nll_loss + eps_i * smooth_loss
#         return loss

#Ref 2 : By hemingkx : https://github.com/hemingkx/ChineseNMT

# class LabelSmoothing(nn.Module):
#     """Implement label smoothing."""

#     def __init__(self, size, padding_idx, smoothing=0.0):
#         super(LabelSmoothing, self).__init__()
#         self.criterion = nn.KLDivLoss(size_average=False)
#         self.padding_idx = padding_idx
#         self.confidence = 1.0 - smoothing
#         self.smoothing = smoothing
#         self.size = size
#         self.true_dist = None


#     def forward(self, x, target):
#         assert x.size(1) == self.size
#         true_dist = x.data.clone()
#         true_dist.fill_(self.smoothing / (self.size - 2))
#         true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
#         true_dist[:, self.padding_idx] = 0
#         mask = torch.nonzero(target.data == self.padding_idx)
#         if mask.dim() > 0:
#             true_dist.index_fill_(0, mask.squeeze(), 0.0)
#         self.true_dist = true_dist
#         return self.criterion(x, Variable(true_dist, requires_grad=False))

beam search
------

In [123]:
import gc
def beam_search_one_step(sentences,p_sentences,n_beam_output):
    # sentences : {type : tensor , shape : (batch X beam_num) X now_sentences_length X 1 value : int}
    # p_sentences : {type : tensor , shape : batch X beam_num X 1 value : log_softmax probability}
    # n_beam_output : {type : tensor , shape : batch X beam_num X dictionary_length,
    # value : [P1,P2,P3...] X beam_num times (Pk in [0,1])}

    '''
    TO DO : (set beam num = K)
    for every batch:
    expand sentences(total number = K) K times (so there are K-square sentences),then concat with
    the index of top K consequence of each beam(total K beams) in n_beam_output (so there are also K-square values).
    '''
    batch = n_beam_output.size(0)
    beam_num = n_beam_output.size(1)
    # sentences : {type : tensor , shape : batch X beam_num X now_sentences_length value : int}
    sentences = sentences.view(batch,beam_num,-1)
    # repeat : {type : tensor , shape : beam_num X 1 ,value : beam_num}
    # each row repeat beam_num times before concatenate
    repeat = torch.full([beam_num],fill_value = beam_num)
    # sentences_expand : {type : tensor , shape : batch X (beam_num X beam_num) X now_sentences_length ,
    # value : [[[A,B...] X beam_num times,[C,D...] X beam_num times}...] A,B,C,D...are int}
    sentences_expand = torch.repeat_interleave(sentences,repeat,dim=1)
    del repeat
    # topk_prob : {type : tensor , shape : batch X beam_num X beam_num, value : log_softmax probability}
    # topk_index : {type : tensor , shape : batch X beam_num X beam_num, value : int}
    topk_prob, topk_index = torch.topk(n_beam_output,dim = -1,k = beam_num)

    # topk_index : {type : tensor , shape : batch X (beam_num X beam_num) X 1, value : int}
    topk_index = topk_index.view(batch,-1,1)
    # sentences : {type : tensor , shape : batch X (beam_num X beam_num) X (now_sentences_length+1), value : int}
    sentences_expand = torch.cat((sentences_expand,topk_index),dim = -1)
    '''
    TO DO :
    multipies p_sentences with the probability of top K consequence of each beam(total K beams) in n_beam_output
    (so there are also K-square values).

    The final step is to choose Top K consequence from K-square sentences by using p_sentences.
    '''

    # p_sentences : {type : tensor , shape : batch X (beam_num X beam_num),
    # value : [P1,P2,P3...] X beam_num times (Pk is log_softmax probability)}
    p_sentences = (p_sentences+topk_prob).view(batch,-1)
    # p_sentences : {type : tensor , shape : batch X beam_num, value : log_softmax probability}
    # p_index : {type : tensor , shape : batch X beam_num, value : int}
    p_sentences, p_index = torch.topk(p_sentences, dim = 1, k = beam_num)
    p_sentences = p_sentences.unsqueeze(-1)
    # row : {type : tensor , shape : batch X 1, value : [[0],[1],[2],...]}
    row = torch.tensor(range(batch)).unsqueeze(1)
    # sentences : {type : tensor , shape : batch X beam_num X (now_sentences_length+1), value : log_softmax probability}
    new_sentences = sentences_expand[row, p_index].view(batch*beam_num,-1)
    sentences.data = new_sentences.data

    del topk_prob,topk_index,sentences_expand,row,new_sentences
    gc.collect()
    return sentences,p_sentences

def get_next_word(model,memory,out,out_probability,id,batch,beam_num,max_sentence_length,dictionary_length,padding_id):
    # padding : {type : tensor , shape : (Batch X beam_num) X (max_sentence_length-(id+1)) ,value : int}
    padding = torch.full(size = (batch*beam_num,max_sentence_length-(id+1)),fill_value = padding_id)
    # out_padding : {type : tensor , shape : (Batch X beam_num) X max_sentence_length X 1,
    # value : [[bos_id],[any_id],...[padding_id],....] X Batch}
    out_padding = torch.cat((out,padding),dim = 1).unsqueeze(-1)
    # tgt_padding : {type : tensor , shape : (Batch X beam_num) X max_sentence_length ,value: bool}
    tgt_padding = (out_padding == padding_id).squeeze(-1)
    # out_add : {type : tensor , shape : Batch X beam_num X dictionary_length ,value : int}
    out_add = model.decoder(memory,out_padding,tgt_padding)[:,id,:].view(batch,beam_num,dictionary_length)

    # out_n_beam : {type : tensor , shape : (Batch X beam_num) X (id+1) ,value : int}
    # out_probability {type : tensor , shape : Batch X beam_num X 1 , value : log_softmax probability}
    out , out_probability = beam_search_one_step(out,out_probability,out_add)

    del padding,out_padding,tgt_padding,out_add
    gc.collect()
    return(out , out_probability)

def apply_beam_search(model,src,src_mask,beam_num,
               max_sentence_length,
               dictionary_length,
               bos_id = setting["tokenized_setting"]["bos_id"],
               padding_id = setting["tokenized_setting"]["pad_id"]):
    with torch.no_grad():
      batch = src.size(0)
      # out : {type : tensor , shape : Batch X 1 , value : bos_id}
      out = torch.full(size = (batch,1),fill_value = bos_id)
      # repeat : {type : tensor , shape : Batch X 1 ,value : beam_num}
      # each row repeat beam_num times before concatenate
      repeat = torch.full([batch],fill_value = beam_num)
      # out_beam_expand : {type : tensor , shape : (Batch X beam_num) X 1 ,value : bos_id}
      out_beam_expand = torch.repeat_interleave(out,repeat,dim=0)

      del out

      # out_probability {type : tensor , shape : Batch X beam_num X 1, value : 0.1}
      out_probability = torch.full(size = (batch,beam_num,1),fill_value = 0.0)

      # memory : {type : tensor , shape : Batch X max_sentence_length X encoder_output_dim ,value : arbitary float}
      memory = model.encoder(src,src_mask)
      # print(f"memor.shape={memory.shape}")
      # memory_beam_expand : {type : tensor , shape : (Batch X n_beam) X max_sentence_length X encoder_output_dim ,value : float}
      memory_beam_expand = torch.repeat_interleave(memory,repeat,dim=0)

      del memory,src,src_mask,repeat

      gc.collect()

      for id in range(max_sentence_length-1):

        # out_n_beam : {type : tensor , shape : (Batch X beam_num) X (id+1) ,value : int}
        # out_probability {type : tensor , shape : Batch X beam_num X 1 , value : log_softmax probability}
        new_out_beam_expand , new_out_probability = \
        get_next_word(model,memory_beam_expand,out_beam_expand,
               out_probability,id,batch,beam_num,max_sentence_length,dictionary_length,padding_id)

        del out_beam_expand,out_probability

        out_beam_expand,out_probability = new_out_beam_expand , new_out_probability

        del new_out_beam_expand , new_out_probability

        gc.collect()

      # out_beam_expand : {type : tensor , shape : Batch X beam_num X (max_sentence_length) ,value : 0 or 1}
      print(out_beam_expand.shape)
      out_beam_expand = out_beam_expand.view(batch,beam_num,max_sentence_length)
      # max_probability : {type : tensor , shape :  Batch  X 1 ,value : int(max prob index)}
      max_probability = torch.argmax(input = out_probability,dim = 1)
      # max_probability_expand : {type : tensor , shape :  Batch  X 1 X max_sentence_length ,
      # value : [[A,A,A....],[B,B,B...],...](A,B are int)}
      max_probability_expand = max_probability.expand(batch, max_sentence_length).unsqueeze(1)
      print(max_probability_expand.shape)
      print(out_beam_expand[0])
      # out : {type : tensor , shape :  Batch X max_sentence_length ,value : [[int,int,...],[int,int...],...]}
      out =  torch.gather(input = out_beam_expand ,dim = 1, index = max_probability_expand)
      print(out.shape)
    return out

In [None]:
# #test
# batch = 3
# beam_num = 2
# sentences = torch.randint(0,8000,(batch*beam_num,5))
# p_sentences = torch.log(torch.rand((batch , beam_num , 1)))
# n_beam_output = torch.rand((batch , beam_num , 8000))
# print(sentences,p_sentences,n_beam_output)
# print(beam_search_one_step(sentences,p_sentences,n_beam_output))
# repeat = torch.full([beam_num],fill_value = beam_num)
# sentences_expand = torch.repeat_interleave(sentences.view(batch,beam_num,-1),repeat,dim=1)
# print(sentences_expand,sentences_expand.shape)

# topk_prob = torch.rand((batch , beam_num , beam_num))
# print(p_sentences)
# print(topk_prob)
# prob = (p_sentences*topk_prob).view(batch,-1)
# print(prob,prob.shape)
# row = torch.tensor(range(batch))
# id = torch.tensor([[[0],[1]],[[1],[0]],[[0],[1]]]).unsqueeze(1)

# print(row)
# print(id.shape)

In [None]:
test = apply_beam_search(model,src,src_mask,2,
               400,
               8000,
              )
print(test)

0
40
80
120
160
200
240
280
320
360
torch.Size([64, 400])
torch.Size([32, 1, 400])
tensor([[   2, 2163, 5610,  711, 6892, 2027,  421, 3564,  484, 5813, 4662, 5009,
         6038,  593, 4666, 1090, 2486, 7320, 1225, 1674, 5813, 5233,  833, 3135,
         7320, 1225,  407, 6046, 3405,  593, 3662, 2358, 4690, 4455, 6470, 1714,
          760, 1621, 6107, 5813, 5233, 3904, 1094, 6973, 5745, 5080, 3255,  889,
         7816, 5495, 5612, 4476, 3075,  833, 7635, 5073, 6318, 4940, 3793,  629,
         4571, 4576, 3387, 5934, 3865, 3037, 2772, 7263,  325, 1713, 7114, 5249,
         3386, 2769, 4914, 1237, 5156, 5392, 1318,  593, 5088, 6283, 6383, 3807,
         3055, 3231, 4020, 2741,   48, 7682, 2409, 1491, 7816, 3765, 4998, 3476,
          787,  593, 1980, 3795, 5267, 5053, 2754, 1621, 2650, 3045, 3184, 3494,
         3694, 6659, 4690, 5892, 5228,  919, 4780, 4327, 1999, 7540, 2607, 4820,
          665, 3347, 7635, 7704,  469,  884,  686, 1209, 1733,  117, 4144, 6973,
         7320, 1458, 2835,

In [None]:
# x = torch.tensor([[1, 2, 3],[4,5,6]])
# y = torch.full([2],fill_value = 3)

# print(y)
# torch.repeat_interleave(x, y, dim=0)


x = torch.tensor([[2,5,7,3],[3,4,6,5],[5,2,1,8],[6,4,3,9],[1,2,9,7],[3,4,8,2]])
p = torch.tensor([[0.9],[0.8],[0.4],[0.2],[0.9],[0.7]])
x = x.view(3,2,4)
p = p.view(3,2,1)
print(p.shape)
max_p = torch.argmax(input = p,dim = 1)
print(max_p.shape)
max_p_expand = max_p.expand(max_p.size(0), x.size(-1)).unsqueeze(1)
print(max_p_expand.shape,x.shape)
print(max_p_expand,x)
out =  torch.gather(input = x ,dim = 1, index = max_p_expand)
out
# t = torch.tensor([[1, 2], [3, 4]])
# torch.gather(t, 1, torch.tensor([[0, 0, 0], [1, 0, 0]]))

torch.Size([3, 2, 1])
torch.Size([3, 1])
torch.Size([3, 1, 4]) torch.Size([3, 2, 4])
tensor([[[0, 0, 0, 0]],

        [[0, 0, 0, 0]],

        [[0, 0, 0, 0]]]) tensor([[[2, 5, 7, 3],
         [3, 4, 6, 5]],

        [[5, 2, 1, 8],
         [6, 4, 3, 9]],

        [[1, 2, 9, 7],
         [3, 4, 8, 2]]])


tensor([[[2, 5, 7, 3]],

        [[5, 2, 1, 8]],

        [[1, 2, 9, 7]]])

In [None]:
a = torch.FloatTensor([[[1,1,1],[2,2,2]],[[9,9,9],[5,5,5]]])
b = torch.LongTensor([1,0])
print(a.shape)
R = a.shape[0]
C = a.shape[2]

idx = b.unsqueeze(dim=1).repeat(1, C).view(R, 1, C)
print(idx)
print(idx.shape)
torch.gather(a, 1, idx)

torch.Size([2, 2, 3])
tensor([[[1, 1, 1]],

        [[0, 0, 0]]])
torch.Size([2, 1, 3])


tensor([[[2., 2., 2.]],

        [[9., 9., 9.]]])

In [None]:
# label_onehot = F.one_hot(tgt,8000).float().squeeze()
# print((tgt<4000).expand(32,400,2)[0])