# Homework №2

This homework will be dedicated to **ASR & Co**.

In general, you may implement any ASR model that was discussed in the lecture,
but we recommend to implement **QuartzNet**.

## **Important aspects (model)**
1) Pay attention on different length of utterances. P.S. **masking**.
    
2) A good ASR is a robust ASR, so we ask you to implement and use at least **4 types of augmentations** (P.S. 2 seminar).

3) Also, to get better quality, we ask you to implement a **beam search** for better decoding.

4) (Bonus) As a bonus you can use **BPE** instead of Char. You can use SentencePiece, HuggingFace or YouTokenToMe.

5) (Bonus) As a bonus you can take pretrained **LM** (or train yourself) and fusing LM with ASR.
    Way of fusing you may choose yourself.

## **Important aspects (code)**
1) You already know about pytorch-lighting (I hope :)) but you are not allowed to use it in this homework.

2) Try to write code more structurally and cleanly !

3) Good logging of experiments save your nerves and time,
    so we ask you to use **W&B** and log at least loss, WER, CER and pairs (audio -- recognized text).
    **Do not remove** the logs until we have checked your work and given you a grade!

4) We also ask you to organize your code in github repo with Docker and setup.py. You can use my template https://github.com/markovka17/dl-start-pack.

5) Your work **must be** reproducable, so fix seed, save the weights of model, and etc.

6) In the end of your work write inference utils. Anyone should be able to take your weight, load it into the model and run it on some audio track.

## Data

1) If you have enough GPU and CPU we recommend to train model on librispeech-100 (100 hours).
    If you poor student your choise is LJSpeech (24 housr) :)

1.1) LJSpeech https://keithito.com/LJ-Speech-Dataset/. Note that audio file is a single-channel 16-bit PCM WAV with a sample rate of 22050 Hz. So, feel free to resample audio in 16000 Hz.
    Target text is **Normalized Transcription** in **transcripts.csv**.

1.2) LibriSpeech https://www.openslr.org/12. Download and use train-clean-100.tar.gz.

Числа: https://drive.google.com/file/d/1HKtLLbiEk0c3l1mKz9LUXRAmKd3DvD0P/view?usp=sharing

CommonVoice Mozilla: https://commonvoice.mozilla.org/en/datasets
Он весит 50 гигабайт. В нем значительно больше коротких записей, что должно ускорить сходимость методов на нем.
Можно обучиться на нем всем или просто отщипнуть себе кусочек.

Все еще настоятельно рекомендую препроцессить датасет и выкидывать все записи длиннее N-секунд (а если вы учите не CTC, то стоит еще дополнительно выкидывать все записи длиннее  K символов), чтобы максимизировать размер батча.

##### config

https://github.com/NVIDIA/NeMo/blob/main/examples/asr/conf/quartznet_15x5.yaml

### Начинаем решение

In [1]:
import wandb

In [1]:
import pandas as pd
import string
import re

import librosa
import torch
import torchaudio
import torch.nn as nn
import torch.nn.functional as F

from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torch import distributions
from tqdm import tqdm

from collections import Counter
from IPython import display as display_
%pylab inline

Populating the interactive namespace from numpy and matplotlib


In [2]:
BATCH_SIZE = 80
NUM_EPOCHS = 150

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

### датасет лоадер

In [4]:
class TrainDataset(torch.utils.data.Dataset):
    """Custom competition dataset."""

    def __init__(self, csv_file, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.answers = pd.read_csv(csv_file, '\t')
        self.transform = transform


    def __len__(self):
        return len(self.answers)


    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        utt_name = 'cv-corpus-5.1-2020-06-22/en/clips/' + self.answers.loc[idx, 'path']
        utt = torchaudio.load(utt_name)[0].squeeze()
        
        if len(utt.shape) != 1:
            print(utt.shape)
            print(utt)
            utt = utt[1]
            
        answer = self.answers.loc[idx, 'sentence']

        if self.transform:
            utt = self.transform(utt)

        sample = {'utt': utt, 'answer': answer}
        return sample

In [5]:
class TestDataset(torch.utils.data.Dataset):
    """Custom competition dataset."""

    def __init__(self, csv_file, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.names = pd.read_csv(csv_file, '\t')
        self.transform = transform


    def __len__(self):
        return len(self.names)


    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        utt_name = 'cv-corpus-5.1-2020-06-22/en/clips/' + self.names.loc[idx, 'path']
        utt = torchaudio.load(utt_name)[0].squeeze()
  

        if self.transform:
            utt = self.transform(utt)

        sample = {'utt': utt}
        return sample

In [6]:
def transform_tr(wav):
    aug_num = torch.randint(low=0, high=3, size=(1,)).item()
    augs = [
        lambda x: x,
        lambda x: (x + distributions.Normal(0, 0.01).sample(x.size())).clamp_(-1, 1),
        lambda x: torchaudio.transforms.Vol(.1)(x)
    ]
    
    return augs[aug_num](wav)

In [7]:
def viz(wav):
    figsize(20, 5)
    plot(wav)
    plt.show()

    display_.display(display_.Audio(wav, rate=48000, normalize=False))

In [8]:
class TextTransform:
    def __init__(self):
        self.char_dict = {}
        self.index_dict = {}
        
        self.char_dict['\''] = 0
        self.index_dict[0] = '\''
        self.char_dict[' '] = 1
        self.index_dict[1] = ' '
        for i, let in enumerate(string.ascii_lowercase):
            self.index_dict[i+2] = let
            self.char_dict[let] = i+2
            
    def text_to_int(self, text):
        labels = []
        for let in text:
            labels.append(self.char_dict[let])
        return labels
    
    def int_to_text(self, labels):
        text = []
        for num in labels:
            text.append(self.index_dict[num])
        return text

In [9]:
import math

In [10]:
def preprocess_data(data):
    text_transform = TextTransform()
    wavs = []
    input_lens = []
    labels = []
    label_lens = []
    
    for el in data:
        wavs.append(el['utt'])
        input_lens.append(math.ceil(mel_len(el['utt'].shape[0]) / 2))   ############# not yet
        label = torch.Tensor(text_transform.text_to_int(re.sub(r'[^a-z ]','', el['answer'].lower()
                                                              )
                                                       )
                            )
        labels.append(label)
        label_lens.append(len(label))
        

    wavs = pad_sequence(wavs, batch_first=True)
    labels = pad_sequence(labels, batch_first=True)
    
    return wavs, input_lens, labels, label_lens    

In [13]:
# Loading data and loaders
my_dataset = TrainDataset(csv_file='cv-corpus-5.1-2020-06-22/en/train.tsv', transform=transform_tr) ## HEYHEYYYY
print('all train+val samples:', len(my_dataset))
test_dataset = TestDataset(csv_file='cv-corpus-5.1-2020-06-22/en/test.tsv', transform=None)

all train+val samples: 435947


In [14]:
#all_lens = []
#for i, el in tqdm(enumerate(my_dataset)):
#    all_lens.append(el['utt'].shape[0])

In [15]:
#all_lens2 = np.array(all_lens)
#np.percentile(all_lens2, 95)

# 422784

#####  412416.0

In [16]:
#all_ind_lens = []

#for i, el in tqdm(enumerate(my_dataset)):
#    all_ind_lens.append([i, el['utt'].shape[0]])

In [17]:
#s = sorted(all_ind_lens, key=lambda x: x[1])

In [18]:
#with open('sorted.npy', 'wb') as f:
#    np.save(f, s)

In [19]:
with open('sorted.npy', 'rb') as f:
    s = np.load(f)

In [20]:
to_save = s[:120000][:, 0]

In [21]:
#val_ixs = to_save[::8]   # 15k

#train_ixs = []
#for i in range(len(to_save)):
#    if i % 8 != 0:
#        train_ixs.append(to_save[i])
#train_ixs = np.array(train_ixs)

In [22]:
my_dataset = torch.utils.data.Subset(my_dataset, to_save)
#len(my_dataset)

In [23]:
#my_dataset, _ = torch.utils.data.random_split(my_dataset, [50, 419277-50])
#test_dataset = my_dataset
#train_set = my_dataset
#val_set = my_dataset

In [24]:
#my_loader = DataLoader(my_dataset, batch_size=BATCH_SIZE, collate_fn=preprocess_data, 
#                       shuffle=True, drop_last=True)
train_set, val_set = torch.utils.data.random_split(my_dataset, [110000, 10000])

#train_set = torch.utils.data.Subset(my_dataset, train_ixs)
#val_set   = torch.utils.data.Subset(my_dataset, val_ixs)


train_loader = DataLoader(train_set, batch_size=BATCH_SIZE,
                          shuffle=True, collate_fn=preprocess_data, drop_last=True,
                          num_workers=0, pin_memory=True)

val_loader = DataLoader(val_set, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=preprocess_data, drop_last=True,
                        num_workers=0, pin_memory=True)



test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

In [25]:
melspec = torchaudio.transforms.MelSpectrogram(
    sample_rate=16000,            ### ? ?? ? ?? 
    n_fft=1024,
    hop_length=256,
    n_mels=64                    ### debatable for 64
).to(device)

melspec_transforms = nn.Sequential(
    torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=1024, hop_length=256,  n_mels=64),
    torchaudio.transforms.FrequencyMasking(freq_mask_param=15),
    torchaudio.transforms.TimeMasking(time_mask_param=35),
).to(device)

In [26]:
#win_len=512, hop_len=256
def mel_len(x):
    return int((x - 1024)/256) + 3

### Модель

In [12]:
import torch
import torchvision
import numpy as np
import random
import asrtoolkit

import torch.nn as nn
import torch.nn.functional as F

In [13]:
def set_seed(seed):
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
set_seed(21)

In [14]:
def count_parameters(model):
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    return sum([np.prod(p.size()) for p in model_parameters])

In [15]:
def conv_bn_act(in_size, out_size, kernel_size, stride=1, dilation=1):
    return nn.Sequential(
        nn.Conv1d(in_size, out_size, kernel_size, stride, dilation=dilation),
        nn.BatchNorm1d(out_size),
        nn.ReLU()
    )


def sepconv_bn(in_size, out_size, kernel_size, stride=1, dilation=1, padding=None):
    if padding is None:
        padding = (kernel_size-1)//2
    return nn.Sequential(
        torch.nn.Conv1d(in_size, in_size, kernel_size, 
                        stride=stride, dilation=dilation, groups=in_size,
                        padding=padding),
        torch.nn.Conv1d(in_size, out_size, kernel_size=1),
        nn.BatchNorm1d(out_size)
    )

In [16]:
class QnetBlock(nn.Module):
    def __init__(self, in_size, out_size, kernel_size, stride=1,
                R=5):
        super().__init__()
        
        self.layers = nn.ModuleList(sepconv_bn(in_size, out_size, kernel_size, stride))
        for i in range(R - 1):
            self.layers.append(nn.ReLU())
            self.layers.append(sepconv_bn(out_size, out_size, kernel_size, stride))
        self.layers = nn.Sequential(*self.layers)
        
        self.residual = nn.ModuleList()
        self.residual.append(torch.nn.Conv1d(in_size, out_size, kernel_size=1))         # requires checking
        self.residual.append(torch.nn.BatchNorm1d(out_size))
        self.residual = nn.Sequential(*self.residual)
    
    def forward(self, x):
        return F.relu(self.residual(x) + self.layers(x))

In [17]:
class QuartzNet(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
                    
                  #conv_bn_act(40, 256, kernel_size=33, stride=2)
        self.c1 = sepconv_bn(40, 256, kernel_size=33, stride=2)
                  
        
        self.blocks = nn.Sequential(
                #         in   out  k   s  R
                QnetBlock(256, 256, 33, 1, R=5),
                QnetBlock(256, 256, 39, 1, R=5),
                QnetBlock(256, 512, 51, 1, R=5),
                QnetBlock(512, 512, 63, 1, R=5),
                QnetBlock(512, 512, 75, 1, R=5)
        )
                  #conv_bn_act(512, 512, kernel_size=87, dilation=2)
        self.c2 = sepconv_bn(512, 512, kernel_size=87, dilation=2, padding=86)
        
        self.c3 = conv_bn_act(512, 1024, kernel_size=1)
        
        self.c4 = conv_bn_act(1024, num_classes, kernel_size=1)
        
        self.init_weights()
        
    def init_weights(self):
        pass
        
        
    def forward(self, x):
        c1 = F.relu(self.c1(x))
        blocks = self.blocks(c1)
        c2 = F.relu(self.c2(blocks))
        c3 = self.c3(c2)
        return self.c4(c3)

In [33]:
### c1 & c2 are separable!

In [34]:
def train_epoch(model, optimizer, dataloader, CTCLoss, device):
    model.train()
    
    losses = []
    
    for i, (wavs, wavs_len, answ, answ_len) in tqdm(enumerate(dataloader)):
        wavs, answ = wavs.to(device), answ.to(device)
        
        trans_wavs = torch.log(melspec_transforms(wavs) + 1e-9)   # .to(device)     # SLOW???  ## checked its CUDA
        
        optimizer.zero_grad()
            
        #print('before model', trans_wavs.shape)
        output = model(trans_wavs)
        #print('after model', output.shape)
        output = F.log_softmax(output, dim=1)                                   ### 2?
        #print('after logsoftmax', output.shape)
        output = output.transpose(0, 1).transpose(0, 2)
        #print('after transpose', output.shape)
        
        
        loss = CTCLoss(output, answ, wavs_len, answ_len)
        loss.backward()        
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), 15)
        optimizer.step()
        losses.append(loss.item())
        if i % 100 == 0:
            wandb.log({'mean_train_loss':loss})
        
    return np.mean(losses)
        

        #print(i)
        #print(wavs.shape)
        #print(wavs_len)
        ##################################################3print(transformed.device)
        #print(transformed.shape)
        #if mel_len(wavs.shape[1]) != transformed.shape[2]:
        #    print("STAP THIS", mel_len(wavs.shape[1]), transformed.shape[2])
        #    break
        #plt.figure()
        #plt.title(ex[0])
        #viz(ex['utt'].squeeze_())
        #plt.show()


In [35]:
def train(model, opt, train_dl, scheduler, CTCLoss, device, n_epochs, val_dl=None):
    
    
    
    for epoch in range(n_epochs):
        print("Epoch {} of {}".format(epoch, n_epochs), 'LR', scheduler.get_last_lr())
        
        mean_loss = train_epoch(model, opt, train_dl, CTCLoss, device)
        print('MEAN EPOCH LOSS IS', mean_loss)
        
        scheduler.step()
        
        if (val_dl != None):
            test(model, opt, val_dl, CTCLoss, device) 

In [36]:
def decoder_func(output, answ, answ_lens, blank_label=0, del_repeated=True):
    # output : [B, freq, porbab(28?)]
    # answ   : []
    
    decoded_preds = []
    decoded_targs  = []
    
    text_transform = TextTransform()

    # batch_freqs : [B, freq]
    batch_freqs = torch.argmax(output, dim=2).transpose(0, 1)
    
    for i, freqs in enumerate(batch_freqs):
        # freqs : [freq]
        
        preds = []
        
        decoded_targs.append(
            text_transform.int_to_text(answ[i][:answ_lens[i]].tolist())   ####  не может быть другая длина?  
        )
        
        for j, num in enumerate(freqs):
            if num != blank_label:
                if del_repeated and j != 0 and num == freqs[j-1]:
                    continue
            preds.append(num.item())
        decoded_preds.append(text_transform.int_to_text(preds))
    
    return decoded_preds, decoded_targs    

In [60]:
def cer(target, pred):
    cer_res = asrtoolkit.cer(''.join(target), ''.join(pred))
    
    wandb.log({"CER": cer_res})
    print('CER', cer_res)
    
    print('target', ''.join(target))
    print('prediction', ''.join(pred))
    
    return cer_res

def wer(target, pred):
    wer_res = asrtoolkit.wer(''.join(target), ''.join(pred))
    wandb.log({"WER": cer_res})
    return cer_res

In [38]:
def test(model, optimizer, dataloader, CTCLoss, device):
    model.eval()
    
    cers, wers = [], []
    losses = []
    
    with torch.no_grad():
        for i, (wavs, wavs_len, answ, answ_len) in enumerate(dataloader):
            wavs, answ = wavs.to(device), answ.to(device)

            trans_wavs = torch.log(melspec(wavs) + 1e-9)     # SLOW???

            output = model(trans_wavs)
            output = F.log_softmax(output, dim=1)                        #?  2 ??
            output = output.transpose(0, 1).transpose(0, 2)
            
            loss = CTCLoss(output, answ, wavs_len, answ_len)
            losses.append(loss.item())
            
            # argmax / beam_search
            preds, targets = decoder_func(output, answ, answ_len)
            for i in range(len(preds)):
                if i == 0:
                    cers.append(cer(targets[i], preds[i]))
                wers.append(wer(targets[i], preds[i]))
                
        avg_cer = np.mean(cers)
        avg_wer = np.mean(wers)
        avg_loss= np.mean(losses)
        print('average test loss is', avg_loss)
        wandb.log({'mean_VAL_loss':avg_loss})


In [39]:
wandb.login()
wandb.init()
train_table = wandb.Table(columns=["Predicted Text", "True Text"])

[34m[1mwandb[0m: Currently logged in as: [33mkirili4ik[0m (use `wandb login --relogin` to force relogin)


In [40]:
model = QuartzNet(28)
print(count_parameters(model))
model.to(device)
wandb.watch(model)

6729892


[<wandb.wandb_torch.TorchGraph at 0x7f0ae637ec50>]

In [41]:
#import torch_optimizer
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim.lr_scheduler import StepLR

opt = torch_optimizer.NovoGrad(
                        model.parameters(),
                        lr=0.01,
                        betas=(0.8, 0.7),
                        weight_decay=0.001,
) # this for bs 32 per GPU

#opt = torch.optim.RMSprop(model.parameters(), weight_decay=0.0001)
#scheduler = StepLR(opt, step_size=2, gamma=0.97) 

scheduler  = CosineAnnealingLR(opt, T_max=50, eta_min=0, last_epoch=-1) # ###### TMAX = MAX NUM OF EPOCHS

In [42]:
CTCLoss = nn.CTCLoss(blank=0).to(device)

In [43]:
train(model, opt, train_loader, scheduler, CTCLoss, device,
     n_epochs=NUM_EPOCHS, val_dl=val_loader)

0it [00:00, ?it/s]

Epoch 0 of 150 LR [0.01]


403it [25:14,  3.75s/it]

torch.Size([2, 205056])
tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -2.7344e-06,
         -7.1339e-06, -7.7367e-05],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -2.7344e-06,
         -7.1339e-06, -7.7367e-05]])


513it [32:05,  3.73s/it]

torch.Size([2, 200448])
tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -7.8108e-05,
         -1.9141e-05, -7.8119e-06],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -7.8108e-05,
         -1.9141e-05, -7.8119e-06]])


632it [39:29,  3.71s/it]

torch.Size([2, 270720])
tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -7.7933e-06,
         -1.6287e-05, -5.2683e-05],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -7.7933e-06,
         -1.6287e-05, -5.2683e-05]])


970it [1:00:21,  3.69s/it]

torch.Size([2, 180864])
tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -4.0159e-05,
         -3.1214e-05,  3.6880e-06],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -4.0159e-05,
         -3.1214e-05,  3.6880e-06]])


1163it [1:12:09,  3.66s/it]

torch.Size([2, 200448])
tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -8.9087e-05,
         -5.9199e-05, -5.9370e-05],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -8.9087e-05,
         -5.9199e-05, -5.9370e-05]])


1187it [1:13:37,  3.66s/it]

torch.Size([2, 198144])
tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 4.8324e-05, 7.6901e-05,
         3.2309e-05],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 4.8324e-05, 7.6901e-05,
         3.2309e-05]])


1206it [1:14:47,  3.67s/it]

torch.Size([2, 131328])
tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0004, 0.0005, 0.0005],
        [0.0000, 0.0000, 0.0000,  ..., 0.0004, 0.0005, 0.0005]])


1375it [1:25:07,  3.71s/it]


MEAN EPOCH LOSS IS 3.4354329833984374
<wandb.data_types.Table object at 0x7f0afcb21950>
target belturbet railway station is open as a railway museum
prediction '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
<wandb.data_types.Table object at 0x7f0afcb21950>
target her father was a grocer and later a police officer
prediction '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''

<wandb.data_types.Table object at 0x7f0afcb21950>
target their coloration is red pink and silvery
prediction ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
<wandb.data_types.Table object at 0x7f0afcb21950>
target a childrens playground is provided
prediction ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''

<wandb.data_types.Table object at 0x7f0afcb21950>
target he is an activist for various political causes
prediction ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
<wandb.data_types.Table object at 0x7f0afcb21950>
target king asserts it is the only slayer song on the album
prediction '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''

<wandb.data_types.Table object at 0x7f0afcb21950>
target the book is a critical first hand account of the criminal justice system
prediction '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
<wandb.data_types.Table object at 0x7f0afcb21950>
target after the war he married india thelma walker
prediction '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''

<wandb.data_types.Table object at 0x7f0afcb21950>
target these old locks can still be seen near nunda
prediction '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
<wandb.data_types.Table object at 0x7f0afcb21950>
target the online community is friendly and helpful
prediction ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''

<wandb.data_types.Table object at 0x7f0afcb21950>
target all three were later fired
prediction '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
<wandb.data_types.Table object at 0x7f0afcb21950>
target problems in the agricultural sector have fueled urbanization
prediction ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''

<wandb.data_types.Table object at 0x7f0afcb21950>
target i was looking for your father
prediction '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
<wandb.data_types.Table object at 0x7f0afcb21950>
target i shall see her again
prediction ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''

<wandb.data_types.Table object at 0x7f0afcb21950>
target the star is billions of miles away
prediction '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
<wandb.data_types.Table object at 0x7f0afcb21950>
target it is the county seat of ida county
prediction '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''

0it [00:00, ?it/s]

<wandb.data_types.Table object at 0x7f0afcb21950>
target blue weaver identified the musicians as shown from memory
prediction '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
average test loss is 3.111426202774048
Epoch 1 of 150 LR [0.01]


99it [06:02,  3.68s/it]

torch.Size([2, 200448])
tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -7.8108e-05,
         -1.9141e-05, -7.8119e-06],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -7.8108e-05,
         -1.9141e-05, -7.8119e-06]])


135it [08:18,  3.70s/it]


KeyboardInterrupt: 