# Training notebook for Huggingface XLSR-Wav2Vec2 Model
## (https://huggingface.co/transformers/model_doc/wav2vec2.html)
## Steps
1. Data preprocessing and preparation
2. Dataset class and dataloader
3. Model preparation
4. training Loop
### Note old script based on (https://huggingface.co/blog/fine-tune-xlsr-wav2vec2)

# Step 1 Dataprocessing
- Data is assumed to be in csv format with columns (path, sentence)
- we have have from different sources (listed below)
- some preprocessing has already been done for data from different sources

In [2]:
from datasets import load_dataset, load_metric
import pandas as pd

#Set all sources of data
commonvoice = "data/commonvoice/train.csv"
singlespeaker = "data/singlespeaker/train.csv"
speechcollector = "data/speechcollector/train.csv"
voxpopuli = "data/fi/train.csv"
eduskunta_1 = "data/eduskunnanpuheet/uudetpuheet/dev-eval/train.csv"
eduskunta_2 = "data/eduskunnanpuheet/uudetpuheet/2008-2016set/train.csv"

test1 = "data/commonvoice/test.csv"
test2 = "data/eduskunnanpuheet/uudetpuheet/dev-eval/test.csv"

train_df = pd.concat([pd.read_csv(commonvoice), pd.read_csv(singlespeaker), pd.read_csv(speechcollector), pd.read_csv(voxpopuli), pd.read_csv(eduskunta_1), pd.read_csv(eduskunta_2)])
test_df = pd.concat([pd.read_csv(test1), pd.read_csv(test2)])

print(f"Training set contains {len(train_df)} Samples")
print(f"test set contains {len(test_df)} Samples")
train_df.head()

Training set contains 106443 Samples
test set contains 1976 Samples


Unnamed: 0,path,sentence
0,/home/sampo/.cache/huggingface/datasets/downlo...,Mitä nyt tekisimme?
1,/home/sampo/.cache/huggingface/datasets/downlo...,Äänestämme tämän vuoksi toisin kuin maataloude...
2,/home/sampo/.cache/huggingface/datasets/downlo...,"Rupeatko remmiin, vai et?"
3,/home/sampo/.cache/huggingface/datasets/downlo...,Äänestin näin ollen mietinnön puolesta.
4,/home/sampo/.cache/huggingface/datasets/downlo...,"Kiitos, että tulitte ja opetitte meille viisau..."


# Remove these specific characters and lower case transcriptions

In [2]:
import random
import pandas as pd
from IPython.display import display, HTML
import re
chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"\“\%\‘\”\�\'\...\…\–\é]'

def custom_remove_special_characters(sent):
    sent = re.sub(chars_to_ignore_regex, '', sent).lower() + " "
    return sent

train_df['sentence'] = train_df['sentence'].apply(custom_remove_special_characters)
test_df['sentence'] = test_df['sentence'].apply(custom_remove_special_characters)
train_df.head()

Unnamed: 0,path,sentence
0,/home/sampo/.cache/huggingface/datasets/downlo...,mitä nyt tekisimme
1,/home/sampo/.cache/huggingface/datasets/downlo...,äänestämme tämän vuoksi toisin kuin maataloude...
2,/home/sampo/.cache/huggingface/datasets/downlo...,rupeatko remmiin vai et
3,/home/sampo/.cache/huggingface/datasets/downlo...,äänestin näin ollen mietinnön puolesta
4,/home/sampo/.cache/huggingface/datasets/downlo...,kiitos että tulitte ja opetitte meille viisaut...


# Create vocabulary of characters in the dataset
- (if there are characters you dont want revise the regex in the previous step)

In [3]:
import itertools

def get_chars(df):
    return set(itertools.chain(*[list(x) for x in df['sentence'].values]))

vocab_list = list(get_chars(train_df).union(get_chars(test_df)))
vocab_dict = {v: k for k, v in enumerate(vocab_list)}
print(vocab_dict)

{'q': 0, 'd': 1, 'g': 2, 'y': 3, 'l': 4, 'e': 5, 'j': 6, 'm': 7, 'a': 8, 'p': 9, 'f': 10, 'v': 11, 'h': 12, 'ä': 13, 't': 14, ' ': 15, 'n': 16, 'k': 17, 'r': 18, 'ö': 19, 'u': 20, 'z': 21, 's': 22, 'c': 23, 'b': 24, 'w': 25, 'i': 26, 'å': 27, 'x': 28, 'o': 29}


# Add special tokens into the vocab and save

In [4]:
vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]

vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)

#for key in vocab_dict.keys():
#    if key != "[PAD]":
#        vocab_dict[key] +=1

import json
with open('vocab.json', 'w') as vocab_file:
    json.dump(vocab_dict, vocab_file)

# Create Hugginface Processor from vocab
- Notice that voxpopuli model assumes clips are sampled at 16000Hz
- used for preprocess, encode and decode inputs

In [5]:
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor

tokenizer = Wav2Vec2CTCTokenizer("./vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")
feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True)
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

# PyTorch Dataset class and Dataloader
- default mode loads and resamples audio files on the fly to save RAM
- loading on the fly does not slow training much
- training samples are sorted according to transcription length to reduce infinities on CTC Loss
- If you dont sort the samples remember to change model flag ctc_zero_infinity to True
- collate function handles padding and batching
- audio files are very memory intensive peak VRAM comsumtion with batch_size = 4 is 18GB

In [6]:
import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchaudio
import librosa

def resample(audio, source_sr, target_sr = 16000):
    audio = librosa.resample(np.asarray(audio), source_sr, target_sr)
    return audio


class CTCDataset(Dataset):
    """
    Dataset class used for Speech recognition with ctc loss
    enables precomputing data as arrays or transforming on the fly
    if dataset does not fit into ram
    """
    def __init__(self, dataframe, processor, mode="otf"):
        
        self.data = dataframe
        self.data.sort_values(by="sentence", key=lambda x: x.str.len(), inplace=True, ascending=False)
        self.processor = processor
        self.mode = mode
        if mode!="otf":
            raise NotImplemented
    
    def _processaudio(self, path):
        data, sr = torchaudio.load(path)
        data = data[0].numpy()
        data = resample(data, sr, 16000)
        
        return data
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):     
        if self.mode == 'otf':
            sent = self.data.iloc[idx, 1]
            data = self._processaudio(self.data.iloc[idx, 0])
            return data, sent
        
    def _precompute(self):
        pass
    
    def reorder_df(self):
        pass
        
    
def collate_fn_otf_train(batch):
    """
    collate function used for training and loading audio data on the fly
    """
    
    lists = list(zip(*batch))
    inputs = processor(lists[0], sampling_rate=16_000, return_tensors="pt", padding=True, pad_to_multiple_of=8)
    with processor.as_target_processor():
        labels = processor(lists[1], padding=True, return_tensors="pt", pad_to_multiple_of=8).input_ids
    return inputs.input_values, inputs.attention_mask, labels

def collate_fn_otf(batch):
    """
    collate function used for training and loading audio data on the fly
    """
    
    lists = list(zip(*batch))
    inputs = processor(lists[0], sampling_rate=16_000, return_tensors="pt", padding=True)
    with processor.as_target_processor():
        labels = processor(lists[1], padding=True, return_tensors="pt").input_ids
    return inputs.input_values, inputs.attention_mask, labels



trainset = CTCDataset(train_df, processor)
testset = CTCDataset(test_df, processor)

trainloader = DataLoader(trainset, batch_size = 4, collate_fn = collate_fn_otf_train, num_workers=8)
testloader = DataLoader(testset, batch_size=1, collate_fn = collate_fn_otf, num_workers=4)


# Load pretrained model from huggingface
- currently using voxpopuli (https://github.com/facebookresearch/voxpopuli)

In [7]:
from transformers import Wav2Vec2ForCTC

model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-large-100k-voxpopuli",
    attention_dropout=0.1,
    hidden_dropout=0.1,
    feat_proj_dropout=0.0,
    mask_time_prob=0.05,
    layerdrop=0.1,
    gradient_checkpointing=True,
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
    ctc_zero_infinity=False
)

#Freeze the weights of the pretrained feature extractor
model.freeze_feature_extractor()

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-100k-voxpopuli and are newly initialized: ['lm_head.weight', 'lm_head.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


# see documentation of model to set parameters

In [8]:
help(Wav2Vec2ForCTC)

Help on class Wav2Vec2ForCTC in module transformers.models.wav2vec2.modeling_wav2vec2:

class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel)
 |  Wav2Vec2ForCTC(config)
 |  
 |  Wav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC). 
 |  Wav2Vec2 was proposed in `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations
 |  <https://arxiv.org/abs/2006.11477>`__ by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
 |  
 |  This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
 |  methods the library implements for all its model (such as downloading or saving etc.).
 |  
 |  This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class. Use
 |  it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
 |  behavior.
 |  
 |  Parameters:
 |      config

# Define utility functions for training

In [9]:
from datasets import load_metric
from tqdm.notebook import tqdm


def decode_output(logits):
    pred_ids = torch.argmax(logits, dim=-1)
    pred = processor.batch_decode(pred_ids)
    return pred[0]

@torch.no_grad()
def evaluation_func(model, dataloader, ref_sentences ,use_amp, device="cuda"):
    """
    return tuple (loss, wer)
    """
    wer = load_metric("wer")
    
    model.eval()
    preds_amp = []
    preds = []
    losses_amp = []
    losses = []
        
    for batch in tqdm(dataloader):
        inputs, masks, labels = batch
        
        output = model(inputs.to(device), masks.to(device), labels = labels.to(device))
        loss = output.loss.item()
        
        with torch.cuda.amp.autocast(enabled=use_amp):
            output_amp = model(inputs.to(device), masks.to(device), labels = labels.to(device))
            loss_amp = output_amp.loss.item()
            
        losses.append(loss)
        losses_amp.append(loss_amp)
        
        logits_amp = output_amp.logits
        logits = output.logits
        
        pred = decode_output(logits)
        pred_amp = decode_output(logits_amp)
        
        preds.append(pred)
        preds_amp.append(pred_amp)
        
    return sum(losses)/len(losses), sum(losses_amp)/len(losses_amp),wer.compute(predictions=preds, references=ref_sentences) ,wer.compute(predictions=preds_amp, references=ref_sentences) 
    
def checkpoint_func(model, save_dir):
    model.save_pretrained(save_directory=save_dir)
    return


#evaluation_func(model, testloader, testset.data.sentence)
#checkpoint_func(model, "testi/")

# set up parameters and training loop
- losses, step_size and WER logged into tensorboard
- note WER evaluation wer is obtained without language model
- one epoch takes considerably less than tqdm estimates at first because samples are sorted from longest to shortest

In [12]:
import transformers
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

save_dir = "test_run1"
device = "cuda"
losses = []
training_losses = []

model.to(device)

use_amp  = True
num_epochs = 20
lr = 0.00025
step_interval = 2
eval_interval = len(trainloader)-1
steps = 0

#setup optimizer and scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
#scheduler = transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=200, num_training_steps=len(trainloader)*num_epochs/step_interval)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 2, gamma=0.5)


scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

best_wer = 1.0
eval_losses = []
eval_wers = []
eval_step = 0

print("starting training loop")
for epoch in range(num_epochs):
    print(f"starting epoch: {epoch+1}")
    model.train()
    losses = []
    i = 0
    for batch in tqdm(trainloader):
        
        inputs, masks, labels = batch
        with torch.cuda.amp.autocast(enabled=use_amp):
            output = model(inputs.to(device), masks.to(device), labels=labels.to(device))
            loss = output.loss/step_interval
        
        scaler.scale(loss).backward()
        #accumulate gradients for step_interval batches
        if (i+1)%step_interval == 0:
            #optimizer.step()
            scaler.step(optimizer)
            scaler.update()
            #scheduler.step()
            optimizer.zero_grad()
            steps+=1
        losses.append(output.loss.item())
        #if i%30==0:
        #    print(output.loss.item())
        #evaluate model and save best WER
        if (i+1)%eval_interval == 0:
            eval_loss,eval_loss_amp, wer, wer_amp = evaluation_func(model, testloader, testset.data.sentence, use_amp, device)
            
            writer.add_scalar('eval/loss', eval_loss, eval_step)
            writer.add_scalar('eval/wer', wer, eval_step)
            writer.add_scalar('eval/loss_amp', eval_loss_amp, eval_step)
            writer.add_scalar('eval/wer_amp', wer_amp, eval_step)
            writer.add_scalar('lr', scheduler.get_last_lr()[0], eval_step)
            eval_losses.append(eval_loss)
            eval_wers.append(wer)
            eval_step +=1
            if wer < best_wer:
                #save model with best test WER
                checkpoint_func(model, save_dir)
                best_wer = wer
            else:
                scheduler.step()
            model.train()
            
        i+=1
    #end of epoch
    epoch_loss = sum(losses)/len(losses)
    training_losses.append(epoch_loss)
    writer.add_scalar('train/epoch_loss', epoch_loss, epoch)
#check for final improvements
eval_loss,eval_loss_amp, wer, wer_amp = evaluation_func(model, testloader, testset.data.sentence, use_amp, device)
if wer < best_wer:
    best_wer = wer
    checkpoint_func(model, save_dir)
print(f"training finished, best WER: {best_wer}")
writer.close()

starting training loop
starting epoch: 1


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

5.858843803405762
7.797380447387695
7.684926986694336
5.518434047698975
3.99419903755188
6.219414710998535
3.1812806129455566
2.941330909729004
2.9052228927612305
2.807572364807129
2.8129589557647705
2.8369622230529785
2.8283865451812744
2.72320556640625


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


2.839996576309204
2.8279201984405518
2.6942644119262695
2.8068947792053223
2.7283594608306885
2.812673568725586
2.7107510566711426
2.85324764251709
2.745150566101074
2.849102020263672
2.8087666034698486
2.729599952697754
2.8416390419006348


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


2.756880521774292
2.6971254348754883
2.878608226776123
2.8405351638793945
2.6057515144348145
2.7973480224609375
2.7345035076141357
2.462458610534668
2.3366775512695312
2.442535161972046
2.3021788597106934
1.9663691520690918
1.9234588146209717


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


1.6248788833618164
1.1326515674591064
0.8620301485061646
0.8041617274284363
0.7578226327896118
1.3452996015548706
0.44853538274765015
0.7520885467529297
0.5265997648239136
0.4469468593597412
1.4729619026184082
0.26274579763412476
0.8705237507820129
0.4823071360588074


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.37300631403923035
0.33751505613327026
0.5621663928031921
0.3844975531101227
0.26316505670547485
0.5790499448776245
0.38351261615753174
0.48735225200653076
0.5135809183120728
0.06608554720878601
0.7848414182662964
0.42876696586608887
0.2893849313259125


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.28311702609062195
0.32164469361305237
0.8614224195480347
0.5570911169052124
0.1970493197441101
0.11819645017385483
0.22873561084270477
0.10539290308952332
0.088357113301754
0.3873504400253296
1.079738736152649
0.5117307901382446
0.16830945014953613


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.2976078987121582
-0.005096446722745895
1.0851731300354004
0.22593514621257782
0.26899614930152893
0.35911881923675537
0.5063778758049011
0.25335928797721863
0.3625792860984802
0.1751619577407837
0.35948511958122253
0.2508024275302887
0.17701733112335205
0.3362533450126648


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.033439185470342636
0.5304275751113892
0.04406324401497841
0.03019705042243004
0.19741946458816528
0.19534969329833984
0.2490214705467224
0.06075934320688248
0.34747496247291565
0.3229009509086609
0.31580156087875366
0.286812961101532
0.09938842803239822


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.6139000654220581
0.6785433292388916
0.017092019319534302
0.16681799292564392
0.08154319226741791
0.4741624593734741
0.02652619406580925
0.2436477541923523
0.49128562211990356
0.3840036392211914
0.13823625445365906
0.042942121624946594
0.08984225988388062


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.17854489386081696
0.14755050837993622
-0.04437124729156494
-0.1936655193567276
0.3277369737625122
0.09485995769500732
0.22184817492961884
-0.22911027073860168
0.9604953527450562
-0.08953160047531128
-0.026035301387310028
0.5252953767776489
-0.7556685209274292

starting epoch: 2


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

4.463578701019287
0.7974756956100464
0.5785737633705139
0.7949084043502808
0.5513147711753845
0.48526066541671753
0.6007698178291321
0.395844429731369
0.4300820827484131
0.3652011752128601
0.3906536102294922
0.3047419488430023
0.4296000599861145
0.3328893184661865


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.4379192590713501
0.6960424184799194
0.337399959564209
0.635111927986145
0.4752110242843628
0.3220958709716797
0.7231906652450562
0.3484245538711548
0.5417481064796448
0.734129011631012
0.4156854748725891
0.7457730770111084
0.28085434436798096


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.04770369827747345
0.09561123698949814
0.31463778018951416
0.3386758267879486
0.2614305913448334
0.19307896494865417
0.29901161789894104
0.08643928915262222
0.25783151388168335
0.11621855944395065
0.5012878775596619
0.23642924427986145
0.32931679487228394


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.15241298079490662
0.18337620794773102
0.29295411705970764
0.11346712708473206
0.12327247858047485
0.7316601872444153
0.10536512732505798
0.2613980770111084
0.11090560257434845
0.0008386671543121338
1.2382450103759766
0.21702586114406586
0.5809657573699951
0.3288195729255676


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.0803970992565155
0.07526377588510513
0.20781123638153076
0.116004578769207
0.06813737750053406
0.061020832508802414
0.12772820889949799
0.21385297179222107
0.3804463744163513
0.04496058076620102
0.11615493148565292
0.22240641713142395
0.06994225084781647


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.1935426890850067
0.13154566287994385
0.4647471308708191
0.46589353680610657
0.035117801278829575
0.023818952962756157
0.12139834463596344
-0.06865687668323517
0.025846097618341446
0.035953618586063385
0.8304511904716492
0.1886911541223526
0.26246631145477295


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.09371134638786316
-0.0959308072924614
0.8856195211410522
0.05906844139099121
0.1064065471291542
0.11203932762145996
0.3489956259727478
0.2765127718448639
0.3168354332447052
0.027663011103868484
0.5349891781806946
0.18768607079982758
-0.06016920879483223
0.09632454067468643


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


-0.09462355077266693
0.367374986410141
0.005201985128223896
-0.04610095173120499
-0.023119695484638214
0.21442186832427979
0.0814446285367012
-0.02544589899480343
0.0026928894221782684
0.19394344091415405
0.1237117201089859
0.061651237308979034
0.08715187013149261


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.28652650117874146
0.47144076228141785
-0.09913342446088791
0.0749564915895462
-0.01745346561074257
0.12572801113128662
-0.06400160491466522
0.09739304333925247
0.7196926474571228
0.23354841768741608
-0.08882185816764832
-0.3095310926437378
-0.0447121262550354


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.07147899270057678
-0.1233869418501854
-0.08997699618339539
-0.3117680549621582
0.33026403188705444
-0.05776464194059372
-0.029502667486667633
-0.17123019695281982
0.49885955452919006
0.1822509616613388
-0.19454966485500336
0.2665042579174042
-0.6769394874572754

starting epoch: 3


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

4.341226100921631
0.3329543471336365
0.5058911442756653
0.8592271208763123
0.44751209020614624
0.3637950122356415
0.544418454170227
0.40817731618881226
0.3218905031681061
0.18211598694324493
0.24452084302902222
0.2516063153743744
0.39313507080078125
0.2349202036857605


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.5120618343353271
0.569783091545105
0.29691770672798157
0.6393551826477051
0.38550078868865967
0.27971482276916504
0.4077602028846741
0.3243488371372223
0.3958481252193451
0.3101550042629242
0.2944333553314209
0.4891981780529022
0.2138185054063797


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.04478813707828522
0.06801337003707886
0.18492910265922546
0.15822744369506836
0.1596398502588272
0.1396532654762268
0.2016482800245285
0.1171778067946434
0.09110262244939804
0.08345123380422592
0.2601882219314575
0.19740042090415955
0.39319393038749695


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.15797185897827148
0.13781388103961945
0.14089590311050415
0.05759163573384285
0.0291525237262249
0.8755197525024414
0.16167300939559937
0.19624163210391998
0.07407920807600021
-0.014062155038118362
1.298454999923706
0.11697643995285034
0.5193305015563965
0.34838056564331055


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.06348402798175812
0.028690796345472336
0.252654492855072
0.10950435698032379
0.030270855873823166
0.08020646870136261
0.10802404582500458
0.23146896064281464
0.08911880850791931
-0.004114143550395966
0.008193429559469223
0.13388068974018097
-0.023441404104232788


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.008387070149183273
0.09456557035446167
0.24719849228858948
0.4302612245082855
0.08554226160049438
-0.020909827202558517
0.029086127877235413
-0.12420515716075897
0.008802436292171478
0.005406521260738373
0.7412952184677124
0.1767524927854538
0.5372858047485352


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.030456025153398514
-0.09675204753875732
0.634823203086853
-0.005975597072392702
0.06365745514631271
0.13061100244522095
0.31521502137184143
0.13067740201950073
0.15508638322353363
0.042596485465765
0.18199685215950012
0.07094181329011917
-0.07606494426727295
0.036922018975019455


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


-0.10349193215370178
0.3544386327266693
0.0010245665907859802
-0.07416890561580658
0.030533283948898315
0.13731536269187927
-0.03205453231930733
-0.12054777890443802
-0.005118196830153465
0.14357157051563263
0.1626220941543579
-0.008542388677597046
-0.02050788700580597


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.21361134946346283
0.3967559337615967
-0.1926906406879425
0.03334175422787666
0.09501834958791733
0.022963259369134903
-0.17553313076496124
0.051668085157871246
0.6520901322364807
0.23078390955924988
0.37393125891685486
-0.055378176271915436
-0.13589945435523987


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


-0.021098362281918526
-0.19357876479625702
-0.20666567981243134
-0.3445570468902588
0.12049897015094757
-0.18464922904968262
-0.10903272032737732
-0.3261335492134094
0.4529375433921814
-0.07372689247131348
-0.27987971901893616
0.41000640392303467
-0.7271722555160522

starting epoch: 4


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

3.7840194702148438
0.40866243839263916
0.4735366702079773
0.6063088178634644
0.44328123331069946
0.2839619815349579
0.45071789622306824
0.3521359860897064
0.2739472985267639
0.35686102509498596
0.23354151844978333
0.22208616137504578
0.30662497878074646
0.2273050844669342


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.3218533992767334
0.5415292382240295
0.23299181461334229
0.5289821028709412
0.19539454579353333
0.2593145966529846
0.32639795541763306
0.2512257397174835
0.37997639179229736
0.24699057638645172
0.25717946887016296
0.436576783657074
0.1972925364971161


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.03668845444917679
-0.0012203482910990715
0.23668408393859863
0.2455134093761444
0.13063155114650726
0.19407477974891663
0.19647222757339478
0.03547557070851326
0.17992377281188965
0.07281790673732758
0.3892126679420471
0.06735792011022568
0.3216741979122162


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.11855319142341614
0.11840103566646576
0.14253517985343933
0.04376853257417679
0.006720326840877533
0.6762596368789673
0.05566845089197159
0.16281458735466003
0.029498694464564323
-0.039070628583431244
1.0750988721847534
0.12366661429405212
0.44593316316604614
0.21026092767715454


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.007299443241208792
0.01617482118308544
0.08364815264940262
0.006071174517273903
-0.023493852466344833
-0.01153610646724701
0.06642630696296692
0.16100534796714783
-0.028029652312397957
-0.05290193110704422
-0.034741565585136414
-0.046741291880607605
-0.02055615931749344


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.037960465997457504
0.07217749953269958
0.2683800756931305
0.4935219883918762
0.09896230697631836
-0.07975046336650848
-0.04016273468732834
-0.14347490668296814
-0.08688405156135559
-0.013337392359972
0.6775747537612915
0.15554974973201752
0.08734817057847977


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


-0.0017844266258180141
-0.08638279885053635
0.6042418479919434
-0.019573885947465897
0.10840769112110138
0.17810676991939545
0.1593736708164215
0.14569292962551117
0.14549973607063293
0.013255288824439049
0.17691388726234436
-0.004390226677060127
-0.07739529013633728
0.02076072245836258


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


-0.14644239842891693
0.2916334867477417
-0.01127447746694088
-0.08968620002269745
-0.08029960095882416
-0.052259933203458786
-0.0827401876449585
-0.11815597116947174
-0.07331215590238571
0.07494017481803894
0.1553768664598465
-0.01924402453005314
-0.015947021543979645


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.29725420475006104
0.4000128507614136
-0.196588397026062
0.13216234743595123
0.06605841219425201
0.0007386207580566406
-0.19108273088932037
0.026141762733459473
0.5903323292732239
0.03835872560739517
-0.15418915450572968
-0.3468564450740814
-0.2983661890029907


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.02300501987338066
-0.1934470534324646
-0.12478728592395782
-0.36340898275375366
0.2507326304912567
-0.17093276977539062
-0.18124054372310638
-0.324947714805603
0.21800082921981812
-0.057038355618715286
-0.3901233673095703
0.18480654060840607
-0.786037802696228

starting epoch: 5


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

3.395339012145996
0.2624753415584564
0.6878087520599365
0.6204136610031128
0.3472158908843994
0.37228161096572876
0.40688449144363403
0.3542187511920929
0.26912134885787964
0.155037522315979
0.2819618582725525
0.27437183260917664
0.35570457577705383
0.21757854521274567


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.27323728799819946
0.6242572069168091
0.45821425318717957
0.5911746621131897
0.21580444276332855
0.23140951991081238
0.3503333032131195
0.26239970326423645
0.3306632936000824
0.23871853947639465
0.26842957735061646
0.33001476526260376
0.15226605534553528


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


-0.012842083349823952
0.012110253795981407
0.13007596135139465
0.0707954615354538
0.1309007853269577
0.12489241361618042
0.22956469655036926
0.037209488451480865
0.052389997988939285
0.08838747441768646
0.18500202894210815
0.08141198009252548
0.2217782586812973


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.056412238627672195
0.07969551533460617
0.1372087150812149
-0.01449424959719181
-0.03399727866053581
0.4279378652572632
0.1238156333565712
0.2242102473974228
0.03659976273775101
-0.05036604031920433
0.9693037867546082
0.05700056627392769
0.5164126753807068
0.1542254239320755


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.023348046466708183
0.003522958606481552
0.12845423817634583
0.016978131607174873
-0.0739968791604042
0.008998844772577286
0.06244003400206566
0.1670634150505066
-0.023265615105628967
-0.012292356230318546
-0.05433547496795654
-0.0010405052453279495
-0.06765381991863251


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


-0.033617109060287476
0.08394694328308105
0.26463547348976135
0.44814780354499817
0.05446327477693558
-0.08711986243724823
0.09419150650501251
-0.07229366153478622
-0.004387367516756058
-0.04870207607746124
0.6637864112854004
0.2060806155204773
0.04506565257906914


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.02342161163687706
-0.13571125268936157
0.7089574933052063
-0.019546769559383392
0.021151460707187653
0.08792002499103546
0.3337456285953522
0.2597320079803467
0.11769406497478485
-0.007912492379546165
0.06627896428108215
-0.054046422243118286
-0.0727754533290863
0.05752826854586601


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


-0.16221968829631805
0.3607677221298218
-0.03246752545237541
-0.10079409182071686
-0.11302365362644196
0.007994605228304863
-0.10538764297962189
-0.09393350034952164
0.07230165600776672
0.11714273691177368
0.132985457777977
-0.03370466083288193
-0.0726371631026268


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.19099172949790955
0.3900912404060364
-0.2755090594291687
-0.03987674415111542
-0.005627710372209549
-0.1620061695575714
-0.24159780144691467
0.13630887866020203
0.29810258746147156
0.11735723912715912
-0.08237837255001068
-0.3289172649383545
-0.38638100028038025


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.2113272100687027
-0.05737638100981712
-0.1967201828956604
-0.3426438570022583
0.10933391004800797
-0.23631629347801208
-0.20163115859031677
-0.31007757782936096
0.08817136287689209
-0.17125332355499268
-0.22675904631614685
0.010959312319755554
-1.0175507068634033

starting epoch: 6


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

3.7683286666870117
0.3299078345298767
0.5244357585906982
0.5430790781974792
0.4854309558868408
0.21502608060836792
0.43145623803138733
0.33583900332450867
0.24035930633544922
0.16511982679367065
0.16671311855316162
0.2791154384613037
0.2473793923854828
0.1839239001274109


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.28545331954956055
0.4104269742965698
0.18842804431915283
0.39057600498199463
0.143044114112854
0.20934084057807922
0.38995689153671265
0.24745354056358337
0.4074842929840088
0.18634286522865295
0.24902129173278809
0.5338796377182007
0.24580591917037964


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.011178453452885151
-0.02633170410990715
0.12751440703868866
0.20606863498687744
0.137472465634346
0.10759852081537247
0.2438477873802185
-0.0088344756513834
0.2617359459400177
0.03125584125518799
0.26721835136413574
0.0706082358956337
0.17484432458877563


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.05418563261628151
0.2400396466255188
0.05304180830717087
0.007907439023256302
0.05313754826784134
0.42493245005607605
0.05374058336019516
0.19290795922279358
0.04651476442813873
0.020687878131866455
0.8274529576301575
0.08637114614248276
0.44929763674736023
0.21238821744918823


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.004966230131685734
0.06191670894622803
0.04732256010174751
0.03448878973722458
-0.07600707560777664
-0.021979205310344696
0.16721287369728088
0.1284838616847992
0.003733256831765175
-0.052455611526966095
-0.08188135176897049
-0.03464280068874359
-0.09735697507858276


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


-0.08475349098443985
0.03881661593914032
0.17363499104976654
0.5001270771026611
-0.0036881230771541595
-0.04864472150802612
-0.0449189655482769
-0.1641010046005249
-0.08514808118343353
-0.028651393949985504
0.6279768347740173
0.07466156035661697
0.06184792518615723


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


-0.008060194551944733
-0.035180073231458664
0.5484451055526733
-0.06946411728858948
0.05590490996837616
-0.07014074921607971
0.1851055920124054
0.08206672221422195
0.07094443589448929
0.07306542992591858
0.03817904740571976
-0.054522544145584106
-0.10293753445148468
-0.016579994931817055


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


-0.1828804612159729
0.16141852736473083
-0.011149175465106964
-0.058126937597990036
-0.083661749958992
-0.08419184386730194
-0.10332071781158447
-0.19133490324020386
-0.1388649046421051
0.07743873447179794
0.1281544417142868
-0.0517808236181736
-0.03792349994182587


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.148258775472641
0.3445694148540497
-0.2827302813529968
-0.05256076529622078
-0.08120202273130417
-0.07812225818634033
-0.21972712874412537
0.05979793518781662
0.41577035188674927
0.011534709483385086
-0.16010035574436188
-0.311257541179657
-0.2576082944869995


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.009078297764062881
-0.24876263737678528
-0.2532755136489868
-0.4110489785671234
0.09316439926624298
-0.1810794174671173
-0.11837723851203918
-0.40853849053382874
0.25915461778640747
-0.13310286402702332
-0.3327951431274414
0.2126576006412506
-1.0602151155471802

starting epoch: 7


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

3.4463183879852295
0.20103926956653595
0.31835806369781494
0.614292562007904
0.401374489068985
0.2836058735847473
0.36457446217536926
0.27771106362342834
0.15497195720672607
0.16482990980148315
0.15249069035053253
0.17644371092319489
0.43781930208206177
0.1651611328125


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.2879638969898224
0.46555274724960327
0.14363974332809448
0.32661330699920654
0.162413090467453
0.31334465742111206
0.29047536849975586
0.2660342752933502
0.29792308807373047
0.1730434000492096
0.21784238517284393
0.27873700857162476
0.16384944319725037


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.01618102565407753
0.0146572794765234
0.12316206842660904
0.039228249341249466
0.07202907651662827
0.13501973450183868
0.20058627426624298
0.0563562735915184
0.032841119915246964
0.14400812983512878
0.16470414400100708
0.0292549729347229
0.15824584662914276


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.19219174981117249
0.09007225185632706
0.07005438208580017
0.01584484800696373
-0.025764646008610725
0.3454965054988861
0.02950301580131054
0.15052680671215057
-0.002077641896903515
-0.0010643303394317627
0.6869913935661316
0.04073730856180191
0.35219961404800415
0.15171492099761963


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.0034823804162442684
-0.015799695625901222
0.1512163281440735
-0.013305284082889557
-0.07526591420173645
-0.059420354664325714
0.0561496764421463
0.09376272559165955
-0.039967723190784454
0.0038414550945162773
-0.050773974508047104
-0.09265860915184021
-0.0759325698018074


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


-0.04935745894908905
0.08015289902687073
0.3074882924556732
0.45402413606643677
0.004599656909704208
-0.045457493513822556
-0.10156212747097015
-0.16737842559814453
-0.1346905678510666
-0.03456158936023712
0.6146731972694397
0.11925230175256729
0.010528176091611385


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.059934355318546295
-0.11407097429037094
0.575597882270813
-0.0480617918074131
-0.02107427828013897
-0.07308713346719742
0.145279660820961
0.12486004829406738
0.06378885358572006
-0.052319325506687164
0.0007379204034805298
-0.0810113400220871
-0.13332049548625946
-0.02415412850677967


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


-0.16654542088508606
0.17546910047531128
0.07844796031713486
-0.11723251640796661
-0.10731261223554611
-0.03076736256480217
-0.05868634581565857
-0.0935937687754631
-0.12168566137552261
0.05551939457654953
0.11232106387615204
-0.09480945020914078
-0.11911490559577942


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.06470119208097458
0.31191128492355347
-0.30861473083496094
-0.03996951878070831
0.01685730367898941
-0.14100675284862518
-0.2647610306739807
-0.0522146038711071
0.2781272828578949
-0.014012835919857025
-0.17925770580768585
-0.3568266034126282
-0.3297158181667328


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.03141433000564575
-0.20455285906791687
-0.2720262110233307
-0.30392611026763916
0.08646073937416077
-0.1112256646156311
-0.3131150007247925
-0.2964664697647095
0.14644619822502136
-0.2189001739025116
-0.3672955632209778
0.09995274990797043
-0.9374464154243469

starting epoch: 8


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

3.598088026046753
0.28554844856262207
0.33965930342674255
0.5395952463150024
0.30853840708732605
0.27051907777786255
0.349638968706131
0.2945331931114197
0.1482352614402771
0.06513139605522156
0.36655741930007935
0.1863354742527008
0.20896995067596436
0.15907613933086395


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.24493899941444397
0.42816784977912903
0.15450061857700348
0.3133394420146942
0.07282967865467072
0.24084460735321045
0.3323991894721985
0.23378682136535645
0.396644651889801
0.19003184139728546
0.2649044394493103
0.26407384872436523
0.2359165996313095


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.025694487616419792
0.026723133400082588
0.11126603186130524
0.016946393996477127
0.11898064613342285
0.11475564539432526
0.16840629279613495
-0.029671475291252136
0.047887399792671204
0.013916725292801857
0.09911841154098511
0.018620319664478302
0.1594756543636322


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.10518952459096909
0.04548212140798569
0.052675485610961914
-0.023220593109726906
-0.005549232475459576
0.30429738759994507
0.0384710319340229
0.10551607608795166
0.009466677904129028
0.04899642616510391
0.5238244533538818
0.0610148161649704
0.3689899146556854
0.10082171857357025


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


-0.01366778090596199
-0.05999847128987312
-0.01763129234313965
0.01272814255207777
-0.10557084530591965
-0.0872003436088562
0.021402757614850998
0.07723024487495422
-0.0013738498091697693
-0.08362605422735214
-0.07961016893386841
-0.07613202929496765
-0.046095412224531174


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


-0.07407387346029282
0.04927030950784683
0.1926238238811493
0.3521626591682434
-0.004054870456457138
-0.08391086012125015
-0.08162348717451096
-0.13298343122005463
-0.10348489880561829
-0.09282834082841873
0.6113987565040588
0.03245532885193825
0.022658443078398705


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


-0.03049144148826599
-0.13281473517417908
0.5465767979621887
-0.022327668964862823
-0.04912756010890007
-0.11352642625570297
0.09379555284976959
0.10850480198860168
0.13273480534553528
-0.03851182758808136
-0.05129121616482735
-0.026813507080078125
-0.06437928229570389
-0.06680186092853546


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


-0.20267567038536072
0.16655854880809784
-0.0034374184906482697
0.0835539847612381
-0.14738915860652924
-0.10328611731529236
-0.09194529056549072
-0.20456451177597046
-0.20529596507549286
0.06275750696659088
-0.014547161757946014
-0.08803471177816391
-0.1199767217040062


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


0.05068126693367958
0.3148282766342163
-0.10264937579631805
0.07165800034999847
-0.08901774883270264
-0.1842462420463562
-0.2110336720943451
-0.01734350621700287
0.27811312675476074
-0.07687254250049591
-0.22190922498703003
-0.3557244539260864
-0.2728058099746704


HBox(children=(FloatProgress(value=0.0, max=428.0), HTML(value='')))


-0.04354049637913704
-0.24774442613124847
-0.2405790090560913
-0.3620331287384033
0.2590668499469757
-0.24846340715885162
-0.35098567605018616
-0.4304982125759125
0.20386682450771332
-0.20529389381408691
-0.40080976486206055
-0.04051950201392174
-1.0197439193725586

starting epoch: 9


HBox(children=(FloatProgress(value=0.0, max=3980.0), HTML(value='')))

3.2461538314819336
0.18307146430015564



Traceback (most recent call last):
  File "/home/sampo/anaconda3/lib/python3.8/multiprocessing/queues.py", line 245, in _feed
    send_bytes(obj)
  File "/home/sampo/anaconda3/lib/python3.8/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/home/sampo/anaconda3/lib/python3.8/multiprocessing/connection.py", line 411, in _send_bytes
    self._send(header + buf)
  File "/home/sampo/anaconda3/lib/python3.8/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe


KeyboardInterrupt: 

True