In [1]:
import pickle

import torch
from torch.utils.data import Dataset, DataLoader

In [2]:
with open("./data/processed/tokenized/spm_tokenized_data.pkl", "rb") as f:
    data = pickle.load(f)

In [3]:
em_formal_train = data["train"]["em_formal"]
em_informal_train = data["train"]["em_informal"]
fr_formal_train = data["train"]["fr_formal"]
fr_informal_train = data["train"]["fr_informal"]

em_formal_test = data["test"]["em_formal"]
em_informal_test = data["test"]["em_informal"]
fr_formal_test = data["test"]["fr_formal"]
fr_informal_test = data["test"]["fr_informal"]

In [4]:
max_len = 300
batch_size = 16
num_workers = 2

In [22]:
class CustomDataset(Dataset):
    def __init__(self, src_data, trg_data, min_len, max_len):
        self.src_data = src_data
        self.trg_data = trg_data
        self.min_len = min_len
        self.max_len = max_len
                
        self.total_padded_data = []
        
        for src, trg in zip(self.src_data, self.trg_data):
            if self.min_len < len(src) < self.max_len:
                self.src_padded_data = torch.zeros(self.max_len)
                self.src_padded_data[:len(src)] = torch.Tensor(src)
            if self.min_len < len(trg) < self.max_len:
                self.trg_padded_data = torch.zeros(self.max_len)
                self.trg_padded_data[:len(trg)] = torch.Tensor(trg)
                
            self.total_padded_data.append((self.src_padded_data, self.trg_padded_data))
                
        self.total_padded_data = tuple(self.total_padded_data)
        
    def __getitem__(self, idx):
        return self.total_padded_data[idx]

    
    def __len__(self):
        return len(self.total_padded_data)
    
#     def __iter__(self):
#         for x in self.data:
#             yield x
    
#     def get_vocab(self):
#         return self.vocab
    
#     def decode(self, x):
#         return self.vocab.DecodeIds(x)

In [23]:
train_data = CustomDataset(em_formal_train, em_informal_train, 2, 300)

In [25]:
train_loader = DataLoader(train_data, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=0)

In [26]:
for _, (src, trg) in enumerate(train_loader):
#     print(i)
    print("src:", src)
    print("trg:", trg)
#     break

src: tensor([[  1.,   8.,  20.,  ...,   0.,   0.,   0.],
        [  1.,  46.,  12.,  ...,   0.,   0.,   0.],
        [  1., 261.,  51.,  ...,   0.,   0.,   0.],
        ...,
        [  1.,  36.,  68.,  ...,   0.,   0.,   0.],
        [  1.,  87., 982.,  ...,   0.,   0.,   0.],
        [  1., 200., 358.,  ...,   0.,   0.,   0.]])
trg: tensor([[1.0000e+00, 1.4000e+01, 1.3020e+03,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.6000e+01, 2.1000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.5000e+01, 3.1000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [1.0000e+00, 1.2900e+02, 1.4710e+03,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 3.8000e+01, 1.5840e+03,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.5590e+03, 2.0100e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])
src: tensor([[  1., 261.,  11.,  ...,   0.,   0.,   0.],
        [  

src: tensor([[1.0000e+00, 1.0500e+03, 7.0500e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.5700e+02, 1.0000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 5.6000e+01, 6.1000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [1.0000e+00, 4.5400e+02, 7.1000e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 9.3600e+02, 7.7000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 4.2000e+01, 4.0000e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])
trg: tensor([[1.0000e+00, 1.4000e+01, 6.5000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 2.7200e+02, 8.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 8.1000e+01, 1.2100e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [1.0000e+00, 2.6300e+02, 1.0760e+03,  ..., 0.0000e+00, 0.0000e+00,
         0.0000

src: tensor([[1.0000e+00, 1.9000e+02, 2.9400e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 5.6000e+01, 1.0000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 8.0000e+00, 4.9000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [1.0000e+00, 8.0000e+00, 4.9000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.1680e+03, 4.2200e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 8.0000e+00, 6.5000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])
trg: tensor([[1.0000e+00, 3.2300e+02, 6.3000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 2.0000e+01, 1.8000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 4.9000e+01, 2.0900e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [1.0000e+00, 1.0430e+03, 1.0430e+03,  ..., 0.0000e+00, 0.0000e+00,
         0.0000

trg: tensor([[1.0000e+00, 2.5500e+02, 9.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.4000e+01, 5.0000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.0080e+03, 7.2000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [1.0000e+00, 5.3700e+02, 1.2000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 8.4000e+01, 2.7800e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 2.1000e+02, 9.0000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])
src: tensor([[  1.,  46.,  20.,  ...,   0.,   0.,   0.],
        [  1.,  87.,  35.,  ...,   0.,   0.,   0.],
        [  1.,   8.,  43.,  ...,   0.,   0.,   0.],
        ...,
        [  1.,  36., 188.,  ...,   0.,   0.,   0.],
        [  1.,  56., 125.,  ...,   0.,   0.,   0.],
        [  1., 984.,  21.,  ...,   0.,   0.,   0.]])
trg: tensor([[  1., 969., 698.,  ...,   0.,   0.,   0.],
        [  

trg: tensor([[1.0000e+00, 1.3410e+03, 5.5000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 3.4000e+02, 9.7300e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 2.0000e+01, 1.4100e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [1.0000e+00, 3.1800e+02, 1.1000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.8400e+02, 1.5000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.2800e+02, 1.7300e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])
src: tensor([[1.0000e+00, 2.2300e+02, 1.2000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 5.6000e+01, 6.1000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 4.6000e+01, 3.5000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [1.0000e+00, 6.3100e+02, 1.5600e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000

trg: tensor([[  1., 693.,  15.,  ...,   0.,   0.,   0.],
        [  1., 258.,  21.,  ...,   0.,   0.,   0.],
        [  1.,  92., 382.,  ...,   0.,   0.,   0.],
        ...,
        [  1., 301., 172.,  ...,   0.,   0.,   0.],
        [  1.,  47.,  23.,  ...,   0.,   0.,   0.],
        [  1., 505.,  42.,  ...,   0.,   0.,   0.]])
src: tensor([[1.0000e+00, 8.0000e+00, 4.9000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 8.0000e+00, 1.4630e+03,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.4260e+03, 1.8000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [1.0000e+00, 6.4500e+02, 2.2000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.4400e+02, 7.2000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 8.0000e+00, 7.0000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])
trg: tensor([[1.0000e+00, 1.4000e+01, 2.6900e+02,  ..., 0.0000e+00, 

trg: tensor([[1.0000e+00, 8.8500e+02, 3.3000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 8.1000e+01, 1.7000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 2.7400e+02, 1.0890e+03,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [1.0000e+00, 2.2500e+02, 3.3000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 8.4100e+02, 1.4000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 2.0200e+02, 2.0000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])
src: tensor([[1.0000e+00, 8.0000e+00, 1.1010e+03,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 4.5500e+02, 1.2000e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 4.6000e+01, 8.7400e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [1.0000e+00, 8.0000e+00, 1.5000e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000

src: tensor([[1.0000e+00, 3.7800e+02, 2.8000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.6300e+02, 1.5000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 3.6000e+01, 1.2100e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [1.0000e+00, 8.0000e+00, 4.9000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 5.2900e+02, 3.0000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.4390e+03, 1.0000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])
trg: tensor([[  1.,  18., 212.,  ...,   0.,   0.,   0.],
        [  1.,  65.,  20.,  ...,   0.,   0.,   0.],
        [  1.,  11., 844.,  ...,   0.,   0.,   0.],
        ...,
        [  1.,  51., 687.,  ...,   0.,   0.,   0.],
        [  1., 231.,  65.,  ...,   0.,   0.,   0.],
        [  1., 690., 216.,  ...,   0.,   0.,   0.]])
src: tensor([[1.0000e+00, 1.7300e+02, 1.7190e+03,  ..., 0.0000e+00, 

src: tensor([[1.0000e+00, 4.4000e+02, 5.2100e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 8.0000e+00, 1.4630e+03,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.5700e+02, 1.0000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [1.0000e+00, 5.6000e+01, 6.2000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 4.6000e+01, 3.5000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 4.6000e+01, 9.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])
trg: tensor([[  1., 303., 496.,  ...,   0.,   0.,   0.],
        [  1., 250., 467.,  ...,   0.,   0.,   0.],
        [  1., 752.,   9.,  ...,   0.,   0.,   0.],
        ...,
        [  1., 236.,  52.,  ...,   0.,   0.,   0.],
        [  1.,  16.,  45.,  ...,   0.,   0.,   0.],
        [  1., 130., 524.,  ...,   0.,   0.,   0.]])
src: tensor([[1.0000e+00, 1.4450e+03, 1.9400e+02,  ..., 0.0000e+00, 

src: tensor([[  1., 247.,  12.,  ...,   0.,   0.,   0.],
        [  1., 144., 208.,  ...,   0.,   0.,   0.],
        [  1., 247.,  12.,  ...,   0.,   0.,   0.],
        ...,
        [  1.,  36., 141.,  ...,   0.,   0.,   0.],
        [  1.,   8.,  62.,  ...,   0.,   0.,   0.],
        [  1.,  96.,  66.,  ...,   0.,   0.,   0.]])
trg: tensor([[1.0000e+00, 7.6300e+02, 4.4000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 2.5000e+02, 2.7000e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.2300e+02, 4.1600e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [1.0000e+00, 1.0180e+03, 1.1440e+03,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.4000e+01, 5.2000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 7.5000e+01, 4.9000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])
src: tensor([[1.0000e+00, 3.4100e+02, 3.3000e+01,  ..., 0.0000e+00, 

trg: tensor([[  1., 202.,  15.,  ...,   0.,   0.,   0.],
        [  1., 745.,  26.,  ...,   0.,   0.,   0.],
        [  1.,  69.,  12.,  ...,   0.,   0.,   0.],
        ...,
        [  1.,  15.,  52.,  ...,   0.,   0.,   0.],
        [  1.,  14.,  69.,  ...,   0.,   0.,   0.],
        [  1., 121., 720.,  ...,   0.,   0.,   0.]])
src: tensor([[  1., 582.,  17.,  ...,   0.,   0.,   0.],
        [  1., 118.,  41.,  ...,   0.,   0.,   0.],
        [  1.,  42., 327.,  ...,   0.,   0.,   0.],
        ...,
        [  1., 176., 108.,  ...,   0.,   0.,   0.],
        [  1., 591.,  24.,  ...,   0.,   0.,   0.],
        [  1.,  96.,  15.,  ...,   0.,   0.,   0.]])
trg: tensor([[  1., 316., 493.,  ...,   0.,   0.,   0.],
        [  1., 236., 249.,  ...,   0.,   0.,   0.],
        [  1., 268., 612.,  ...,   0.,   0.,   0.],
        ...,
        [  1., 348.,  70.,  ...,   0.,   0.,   0.],
        [  1.,  87., 319.,  ...,   0.,   0.,   0.],
        [  1.,  75.,  20.,  ...,   0.,   0.,   0.]])
src: te

src: tensor([[1.0000e+00, 4.6000e+01, 1.2000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.6300e+02, 1.4800e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.4580e+03, 1.7000e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [1.0000e+00, 1.7430e+03, 2.4500e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 3.6000e+01, 6.1100e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 3.6000e+01, 1.4900e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])
trg: tensor([[1.0000e+00, 1.6000e+01, 1.2000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 5.5200e+02, 1.7300e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 6.1400e+02, 2.0400e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [1.0000e+00, 8.7000e+01, 4.3400e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000

trg: tensor([[1.0000e+00, 4.7200e+02, 7.7100e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 7.7800e+02, 4.3200e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 2.2500e+02, 1.0000e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [1.0000e+00, 5.5700e+02, 1.6110e+03,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 2.8600e+02, 9.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 7.5300e+02, 2.4000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])
src: tensor([[  1., 144.,  10.,  ...,   0.,   0.,   0.],
        [  1., 155., 171.,  ...,   0.,   0.,   0.],
        [  1., 378., 125.,  ...,   0.,   0.,   0.],
        ...,
        [  1.,   8.,  49.,  ...,   0.,   0.,   0.],
        [  1., 127., 278.,  ...,   0.,   0.,   0.],
        [  1., 157.,  10.,  ...,   0.,   0.,   0.]])
trg: tensor([[1.0000e+00, 7.2000e+01, 9.1000e+01,  ..., 0.0000e+00, 

trg: tensor([[1.0000e+00, 1.5000e+01, 4.8800e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.2020e+03, 1.6500e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 5.4200e+02, 9.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [1.0000e+00, 1.4000e+01, 6.9000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 7.7600e+02, 1.0550e+03,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.4000e+01, 2.6900e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])
src: tensor([[  1.,  11.,  60.,  ...,   0.,   0.,   0.],
        [  1.,   8.,  49.,  ...,   0.,   0.,   0.],
        [  1., 118., 162.,  ...,   0.,   0.,   0.],
        ...,
        [  1.,  46.,  12.,  ...,   0.,   0.,   0.],
        [  1., 975.,  52.,  ...,   0.,   0.,   0.],
        [  1., 144.,  10.,  ...,   0.,   0.,   0.]])
trg: tensor([[  1.,  11.,  66.,  ...,   0.,   0.,   0.],
        [  

trg: tensor([[  1., 152.,  71.,  ...,   0.,   0.,   0.],
        [  1.,  15., 157.,  ...,   0.,   0.,   0.],
        [  1.,  43.,  50.,  ...,   0.,   0.,   0.],
        ...,
        [  1.,  13., 781.,  ...,   0.,   0.,   0.],
        [  1.,  72.,   9.,  ...,   0.,   0.,   0.],
        [  1.,  15.,  50.,  ...,   0.,   0.,   0.]])
src: tensor([[1.0000e+00, 1.2800e+02, 1.2000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 8.0000e+00, 3.5000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.0190e+03, 6.4700e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [1.0000e+00, 3.6000e+01, 6.8000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 6.3100e+02, 1.5600e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.9000e+02, 8.2800e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])
trg: tensor([[1.0000e+00, 4.8400e+02, 1.2000e+01,  ..., 0.0000e+00, 

src: tensor([[1.0000e+00, 8.0000e+00, 3.9000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 2.8500e+02, 2.8900e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.0700e+02, 7.9000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [1.0000e+00, 1.2800e+02, 1.2000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 9.6000e+01, 1.5000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 5.6000e+01, 1.0890e+03,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])
trg: tensor([[1.0000e+00, 2.2000e+01, 1.1200e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.0100e+02, 7.7400e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 6.0400e+02, 4.8000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [1.0000e+00, 5.7000e+01, 8.6000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000

src: tensor([[1.0000e+00, 3.7800e+02, 2.3000e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 4.6000e+01, 2.5000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 2.7900e+02, 1.3000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [1.0000e+00, 1.0430e+03, 1.1000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.7600e+02, 1.0800e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 8.0000e+00, 3.0000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])
trg: tensor([[1.0000e+00, 2.1200e+02, 1.1200e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.6000e+01, 1.6700e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 3.1400e+02, 8.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [1.0000e+00, 1.3300e+02, 4.8300e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000

src: tensor([[1.0000e+00, 8.7000e+01, 1.0800e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 9.6000e+01, 1.5000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 8.7000e+01, 1.0800e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [1.0000e+00, 6.4500e+02, 2.2000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 2.7500e+02, 7.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.1680e+03, 4.2200e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])
trg: tensor([[1.0000e+00, 3.8000e+01, 1.2000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 8.4000e+01, 7.5000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 4.9700e+02, 7.8400e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [1.0000e+00, 1.2690e+03, 2.4000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000

src: tensor([[  1., 118., 162.,  ...,   0.,   0.,   0.],
        [  1.,  56., 125.,  ...,   0.,   0.,   0.],
        [  1., 209.,  12.,  ...,   0.,   0.,   0.],
        ...,
        [  1.,   8.,  50.,  ...,   0.,   0.,   0.],
        [  1.,   8., 180.,  ...,   0.,   0.,   0.],
        [  1., 157.,  10.,  ...,   0.,   0.,   0.]])
trg: tensor([[1.0000e+00, 1.1400e+02, 1.9500e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 2.0000e+01, 2.4900e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.2330e+03, 1.6600e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [1.0000e+00, 1.4000e+01, 1.4720e+03,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.4000e+01, 3.3700e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 8.4100e+02, 9.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])
src: tensor([[1.0000e+00, 3.6000e+01, 1.1150e+03,  ..., 0.0000e+00, 

trg: tensor([[  1., 129., 389.,  ...,   0.,   0.,   0.],
        [  1.,  87., 666.,  ...,   0.,   0.,   0.],
        [  1., 129., 823.,  ...,   0.,   0.,   0.],
        ...,
        [  1.,  14.,  61.,  ...,   0.,   0.,   0.],
        [  1., 366., 558.,  ...,   0.,   0.,   0.],
        [  1.,  14., 223.,  ...,   0.,   0.,   0.]])
src: tensor([[1.0000e+00, 1.2800e+02, 1.2100e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.4580e+03, 2.8000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 8.0000e+00, 4.9700e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [1.0000e+00, 1.7010e+03, 9.2200e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 8.0000e+00, 9.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 9.9800e+02, 7.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])
trg: tensor([[1.0000e+00, 1.2900e+02, 5.4000e+01,  ..., 0.0000e+00, 

src: tensor([[1.0000e+00, 8.0000e+00, 3.1000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 5.0000e+02, 5.9000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.0590e+03, 1.7000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [1.0000e+00, 7.3900e+02, 1.5000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 8.0000e+00, 1.6800e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.5470e+03, 1.8600e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])
trg: tensor([[1.0000e+00, 1.5000e+01, 1.2000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 3.5600e+02, 5.9200e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.7100e+02, 3.6100e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [1.0000e+00, 1.1060e+03, 2.0000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000

trg: tensor([[1.0000e+00, 1.4000e+01, 4.0000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 7.7800e+02, 1.5800e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 7.6000e+01, 6.4100e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [1.0000e+00, 1.4000e+01, 5.0000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.0080e+03, 1.1600e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.8000e+01, 2.8600e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])
src: tensor([[  1., 745.,  10.,  ...,   0.,   0.,   0.],
        [  1., 777.,   8.,  ...,   0.,   0.,   0.],
        [  1., 620.,  97.,  ...,   0.,   0.,   0.],
        ...,
        [  1.,   8., 315.,  ...,   0.,   0.,   0.],
        [  1.,   8.,  31.,  ...,   0.,   0.,   0.],
        [  1.,  56.,  61.,  ...,   0.,   0.,   0.]])
trg: tensor([[  1., 255.,  60.,  ...,   0.,   0.,   0.],
        [  

trg: tensor([[  1.,  91., 109.,  ...,   0.,   0.,   0.],
        [  1.,  51., 232.,  ...,   0.,   0.,   0.],
        [  1.,  15., 280.,  ...,   0.,   0.,   0.],
        ...,
        [  1.,  14.,  52.,  ...,   0.,   0.,   0.],
        [  1.,  14.,  50.,  ...,   0.,   0.,   0.],
        [  1., 372., 150.,  ...,   0.,   0.,   0.]])
src: tensor([[  1.,   8.,  65.,  ...,   0.,   0.,   0.],
        [  1., 223., 152.,  ...,   0.,   0.,   0.],
        [  1.,   8.,  20.,  ...,   0.,   0.,   0.],
        ...,
        [  1.,  36., 103.,  ...,   0.,   0.,   0.],
        [  1.,   8., 557.,  ...,   0.,   0.,   0.],
        [  1.,  36.,  68.,  ...,   0.,   0.,   0.]])
trg: tensor([[1.0000e+00, 1.4200e+02, 3.5000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.3000e+01, 1.2200e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 3.7500e+02, 7.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [1.0000e+00, 3.9000e+

trg: tensor([[1.0000e+00, 2.0000e+01, 2.4900e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.1240e+03, 6.9700e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 8.4000e+01, 7.5100e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [1.0000e+00, 1.5000e+01, 1.0500e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 3.7200e+02, 7.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 3.4800e+02, 4.8900e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])
src: tensor([[1.0000e+00, 1.4400e+02, 1.7000e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 8.0000e+00, 5.0000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.2810e+03, 1.9900e+02,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [1.0000e+00, 8.0000e+00, 4.9000e+01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000

In [95]:
d = torch.rand(1, 16, 2048)
print(d.size())
# d

torch.Size([1, 16, 2048])


In [96]:
d = d[:, -1, :]
d.size()

torch.Size([1, 2048])

In [99]:
m = torch.nn.Softmax(dim=1)
input = torch.randn(2, 3)
print(input.size())
output = m(input)
print(output.size())

torch.Size([2, 3])
torch.Size([2, 3])


In [127]:
predicted = torch.rand(300, 16, 1800)
predicted.transpose(0, 1)
predicted.size()

torch.Size([300, 16, 1800])

In [106]:
predicted[1:].view(-1, predicted.size(-1)).size()

torch.Size([4784, 1800])

In [110]:
predicted[1:].view(-1, predicted.size(-1)).size()

torch.Size([4784, 1800])

In [129]:
predicted.view(-1, predicted.size(0)).size()

torch.Size([28800, 300])