# Первые пипы и скачивание датасета

In [1]:
!wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
!tar -xjf LJSpeech-1.1.tar.bz2

--2021-12-16 10:45:01--  https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
Resolving data.keithito.com (data.keithito.com)... 174.138.79.61
Connecting to data.keithito.com (data.keithito.com)|174.138.79.61|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2748572632 (2.6G) [application/octet-stream]
Saving to: ‘LJSpeech-1.1.tar.bz2’


2021-12-16 10:45:26 (110 MB/s) - ‘LJSpeech-1.1.tar.bz2’ saved [2748572632/2748572632]



In [2]:
!pip install librosa



In [3]:
!pip install torch==1.10.0+cu111 torchaudio==0.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html

Looking in links: https://download.pytorch.org/whl/torch_stable.html


In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import math

In [5]:
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# Featurizer

(дано)

In [6]:
from IPython import display
from dataclasses import dataclass

import torch
from torch import nn

import torchaudio

import librosa
from matplotlib import pyplot as plt


@dataclass
class MelSpectrogramConfig:
    sr: int = 22050
    win_length: int = 1024
    hop_length: int = 256
    n_fft: int = 1024
    f_min: int = 0
    f_max: int = 8000
    n_mels: int = 80
    power: float = 1.0

    # value of melspectrograms if we fed a silence into `MelSpectrogram`
    pad_value: float = -11.5129251


class MelSpectrogram(nn.Module):

    def __init__(self, config: MelSpectrogramConfig):
        super(MelSpectrogram, self).__init__()

        self.config = config

        self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
            sample_rate=config.sr,
            win_length=config.win_length,
            hop_length=config.hop_length,
            n_fft=config.n_fft,
            f_min=config.f_min,
            f_max=config.f_max,
            n_mels=config.n_mels
        )

        # The is no way to set power in constructor in 0.5.0 version.
        self.mel_spectrogram.spectrogram.power = config.power

        # Default `torchaudio` mel basis uses HTK formula. In order to be compatible with WaveGlow
        # we decided to use Slaney one instead (as well as `librosa` does by default).
        mel_basis = librosa.filters.mel(
            sr=config.sr,
            n_fft=config.n_fft,
            n_mels=config.n_mels,
            fmin=config.f_min,
            fmax=config.f_max
        ).T
        self.mel_spectrogram.mel_scale.fb.copy_(torch.tensor(mel_basis))

    def forward(self, audio: torch.Tensor) -> torch.Tensor:
        """
        :param audio: Expected shape is [B, T]
        :return: Shape is [B, n_mels, T']
        """

        mel = self.mel_spectrogram(audio) \
            .clamp_(min=1e-5) \
            .log_()

        return mel

In [7]:
featurizer = MelSpectrogram(MelSpectrogramConfig())

---

# Dataset

Датасет был честно скоммунизден из репозитория HIFI https://github.com/jik876/hifi-gan

In [77]:
import math
import os
import random
import torch
import torch.utils.data
import numpy as np
from librosa.util import normalize
from scipy.io.wavfile import read
from librosa.filters import mel as librosa_mel_fn

MAX_WAV_VALUE = 32768.0


def load_wav(full_path):
    sampling_rate, data = read(full_path)
    return data, sampling_rate


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)


def dynamic_range_decompression(x, C=1):
    return np.exp(x) / C


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    output = dynamic_range_decompression_torch(magnitudes)
    return output


mel_basis = {}
hann_window = {}


def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    if torch.min(y) < -1.:
        print('min value is ', torch.min(y))
    if torch.max(y) > 1.:
        print('max value is ', torch.max(y))

    global mel_basis, hann_window
    if fmax not in mel_basis:
        mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
        mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)

    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
    y = y.squeeze(1)

    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
                      center=center, pad_mode='reflect', normalized=False, onesided=True)

    spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))

    spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)
    spec = spectral_normalize_torch(spec)

    return spec


###Изменила
def get_dataset_filelist(path_to_dataset, filename_train, filename_val=None):
    input_wavs_dir = path_to_dataset + "/wavs"
    with open(filename_train, 'r', encoding='utf-8') as fi:
        training_files = [os.path.join(input_wavs_dir, x.split('|')[0] + '.wav')
                          for x in fi.read().split('\n') if len(x) > 0]
    if filename_val is None:
      return training_files

    with open(filename_val, 'r', encoding='utf-8') as fi:
        validation_files = [os.path.join(input_wavs_dir, x.split('|')[0] + '.wav')
                            for x in fi.read().split('\n') if len(x) > 0]
    return training_files, validation_files



class MelDataset(torch.utils.data.Dataset):
    def __init__(self, training_files, segment_size, n_fft, num_mels,
                 hop_size, win_size, sampling_rate,  fmin, fmax, split=True, shuffle=True, n_cache_reuse=1,
                 device=None, fmax_loss=None, fine_tuning=False, base_mels_path=None):
        self.audio_files = training_files
        random.seed(1234)
        if shuffle:
            random.shuffle(self.audio_files)
        self.segment_size = segment_size
        self.sampling_rate = sampling_rate
        self.split = split
        self.n_fft = n_fft
        self.num_mels = num_mels
        self.hop_size = hop_size
        self.win_size = win_size
        self.fmin = fmin
        self.fmax = fmax
        self.fmax_loss = fmax_loss
        self.cached_wav = None
        self.n_cache_reuse = n_cache_reuse
        self._cache_ref_count = 0
        self.device = device
        self.fine_tuning = fine_tuning
        self.base_mels_path = base_mels_path

    def __getitem__(self, index):
        filename = self.audio_files[index]
        if self._cache_ref_count == 0:
            audio, sampling_rate = load_wav(filename)
            audio = audio / MAX_WAV_VALUE
            if not self.fine_tuning:
                audio = normalize(audio) * 0.95
            self.cached_wav = audio
            if sampling_rate != self.sampling_rate:
                raise ValueError("{} SR doesn't match target {} SR".format(
                    sampling_rate, self.sampling_rate))
            self._cache_ref_count = self.n_cache_reuse
        else:
            audio = self.cached_wav
            self._cache_ref_count -= 1

        audio = torch.FloatTensor(audio)
        audio = audio.unsqueeze(0)

        if not self.fine_tuning:
            if self.split:
                if audio.size(1) >= self.segment_size:
                    max_audio_start = audio.size(1) - self.segment_size
                    audio_start = random.randint(0, max_audio_start)
                    audio = audio[:, audio_start:audio_start+self.segment_size]
                else:
                    audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')

            mel = mel_spectrogram(audio, self.n_fft, self.num_mels,
                                  self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax,
                                  center=False)
        else:
            mel = np.load(
                os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + '.npy'))
            mel = torch.from_numpy(mel)

            if len(mel.shape) < 3:
                mel = mel.unsqueeze(0)

            if self.split:
                frames_per_seg = math.ceil(self.segment_size / self.hop_size)

                if audio.size(1) >= self.segment_size:
                    mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
                    mel = mel[:, :, mel_start:mel_start + frames_per_seg]
                    audio = audio[:, mel_start * self.hop_size:(mel_start + frames_per_seg) * self.hop_size]
                else:
                    mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), 'constant')
                    audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')

        mel_loss = mel_spectrogram(audio, self.n_fft, self.num_mels,
                                   self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax_loss,
                                   center=False)

        return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())

    def __len__(self):
        return len(self.audio_files)

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [78]:
training_files = get_dataset_filelist('/content/LJSpeech-1.1', '/content/one_batch_man.txt')

In [79]:
training_files

['/content/LJSpeech-1.1/wavs/LJ001-0001.wav',
 '/content/LJSpeech-1.1/wavs/LJ001-0002.wav',
 '/content/LJSpeech-1.1/wavs/LJ001-0003.wav',
 '/content/LJSpeech-1.1/wavs/LJ001-0004.wav',
 '/content/LJSpeech-1.1/wavs/LJ001-0005.wav']

In [80]:
train_dataset  = MelDataset(training_files, 
                            segment_size=8192, 
                            n_fft =1024, 
                            num_mels = 80,
                            hop_size = 256, 
                            win_size = 1024, 
                            sampling_rate = 22050,  
                            fmin = 0, 
                            fmax = 8000 
                            )

In [81]:
dataloader = DataLoader(train_dataset, batch_size=5)

---

---

# Grapheme Aligner

In [11]:
@dataclass
class Point:
    token_index: int
    time_index: int
    score: float


@dataclass
class Segment:
    label: str
    start: int
    end: int
    score: float

    def __repr__(self):
        return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"

    @property
    def length(self):
        return self.end - self.start


class GraphemeAligner(nn.Module):

    def __init__(self):
        super().__init__()

        self._wav2vec2 = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H.get_model()
        self._labels = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H.get_labels()
        self._char2index = {c: i for i, c in enumerate(self._labels)}
        self._unk_index = self._char2index['<unk>']
        self._resampler = torchaudio.transforms.Resample(
            orig_freq=MelSpectrogramConfig.sr, new_freq=16_000
        )

    def _decode_text(self, text):
        text = text.replace(' ', '|').upper()
        return torch.tensor([
            self._char2index.get(char, self._unk_index)
            for char in text
        ]).long()

    @torch.no_grad()
    def forward(
        self,
        wavs: torch.Tensor,
        wav_lengths: torch.Tensor,
        texts: Union[str, List[str]]
    ):
        if isinstance(texts, str):
            texts = [texts]
        batch_size = wavs.shape[0]

        durations = []
        for index in range(batch_size):
            current_wav = wavs[index, :wav_lengths[index]].unsqueeze(dim=0)
            current_wav = self._resampler(current_wav)
            emission, _ = self._wav2vec2(current_wav)
            emission = emission.log_softmax(dim=-1).squeeze(dim=0).cpu()

            tokens = self._decode_text(texts[index])

            trellis = self._get_trellis(emission, tokens)
            path = self._backtrack(trellis, emission, tokens)
            segments = self._merge_repeats(texts[index], path)

            num_frames = emission.shape[0]
            relative_durations = torch.tensor([
                segment.length / num_frames for segment in segments
            ])

            durations.append(relative_durations)
            
        durations = pad_sequence(durations).transpose(0, 1)
        return durations

    def _get_trellis(self, emission, tokens, blank_id=0):
        num_frame = emission.size(0)
        num_tokens = len(tokens)

        # Trellis has extra dimension for both time axis and tokens.
        # The extra dim for tokens represents <SoS> (start-of-sentence)
        # The extra dim for time axis is for simplification of the code.
        trellis = torch.full((num_frame + 1, num_tokens + 1), -float('inf'))
        trellis[:, 0] = 0
        for t in range(num_frame):
            trellis[t + 1, 1:] = torch.maximum(
                # Score for staying at the same token
                trellis[t, 1:] + emission[t, blank_id],

                # Score for changing to the next token
                trellis[t, :-1] + emission[t, tokens],
            )
        return trellis

    def _backtrack(self, trellis, emission, tokens, blank_id=0):
        # Note:
        # j and t are indices for trellis, which has extra dimensions
        # for time and tokens at the beginning.
        # When refering to time frame index `T` in trellis,
        # the corresponding index in emission is `T-1`.
        # Similarly, when refering to token index `J` in trellis,
        # the corresponding index in transcript is `J-1`.
        j = trellis.size(1) - 1
        t_start = torch.argmax(trellis[:, j]).item()

        path = []
        for t in range(t_start, 0, -1):
            # 1. Figure out if the current position was stay or change
            # Note (again):
            # `emission[J-1]` is the emission at time frame `J` of trellis dimension.
            # Score for token staying the same from time frame J-1 to T.
            stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
            # Score for token changing from C-1 at T-1 to J at T.
            changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]

            # 2. Store the path with frame-wise probability.
            prob = emission[t - 1, tokens[j - 1]
                            if changed > stayed else 0].exp().item()
            # Return token index and time index in non-trellis coordinate.
            path.append(Point(j - 1, t - 1, prob))

            # 3. Update the token
            if changed > stayed:
                j -= 1
                if j == 0:
                    break

        else:
            raise ValueError('Failed to align')

        return path[::-1]

    def _merge_repeats(self, text, path):
        i1, i2 = 0, 0
        segments = []
        while i1 < len(path):
            while i2 < len(path) and path[i1].token_index == path[i2].token_index:
                i2 += 1
            score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
            segments.append(
                Segment(
                    text[path[i1].token_index],
                    path[i1].time_index,
                    path[i2 - 1].time_index + 1,
                    score
                )
            )
            i1 = i2

        return segments

    @staticmethod
    def plot_trellis_with_path(trellis, path):
        # to plot trellis with path, we take advantage of 'nan' value
        trellis_with_path = trellis.clone()
        for i, p in enumerate(path):
            trellis_with_path[p.time_index, p.token_index] = float('nan')
        plt.imshow(trellis_with_path[1:, 1:].T, origin='lower')

In [12]:

aligner = GraphemeAligner().to(DEVICE)

Downloading: "https://download.pytorch.org/torchaudio/models/wav2vec2_fairseq_base_ls960_asr_ls960.pth" to /root/.cache/torch/hub/checkpoints/wav2vec2_fairseq_base_ls960_asr_ls960.pth


  0%|          | 0.00/360M [00:00<?, ?B/s]

# HIFI

Здесть пока только генератор

In [13]:
def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)

In [109]:
class ResBlock(nn.Module):
  def __init__(self, 
               D_r, # Матрица dilation rates
               k_r, # Вектор kernel sizes
               channels,
               slope=0.1
               ):
    super(ResBlock, self).__init__()

    assert len(D_r.shape) == 2, 'D_r shape is ' + str(D_r.shape)
    assert len(k_r.shape) == 1, 'k_r shape is ' + str(k_r.shape)

    net = []
    for m in range(len(D_r)):
      for l in range(len(D_r[m])):

        net.append(nn.Sequential(*[
                                nn.LeakyReLU(slope),
                                nn.Conv1d(channels, 
                                          channels,
                                          kernel_size=k_r[m],
                                          dilation=D_r[m][l],
                                          padding=get_padding(k_r[m], D_r[m][l])
                                          )
                               ])
                         )
    #Я помню эту ошибку из предыдущей домашки
    self.net = nn.ModuleList(net)

  #Вроде надо последовательно подавать в блоки, не забывая прибавлять skip-connection
  #TODO понять, так ли это  
  def forward(self, x):
    for block in self.net:
      #x += block(x) эта сволочь не работает
      #Костыли
      out = block(x)
      x = x + out
    return x

In [39]:
import numpy as np
k_u = np.array([16, 16, 4, 4])
k_r = np.array([3, 7, 11])
D_r = np.array([[[1, 1], [3, 1], [5, 1]], [[1, 1], [3, 1], [5, 1]], [[1, 1], [3, 1], [5, 1]]]  )

In [40]:
res = ResBlock( 
               D_r[0], # Матрица dilation rates
               k_r, # Вектор kernel sizes
               3
               )

x = torch.rand(8,3,128)

In [41]:
x_res = res(x)
x_res.shape

torch.Size([8, 3, 128])

In [44]:
class MRFBlock(nn.Module):

  def __init__(self, 
               channels,
               D_r, # 3D
               k_r # 1D
               ):
    super(MRFBlock, self).__init__()

    #Там будет деление на эту штуку
    #self.num_kernels = len(h.resblock_kernel_sizes)
    #Вроде это оно
    self.num_kernels = len(k_r)

    assert len(D_r.shape) == 3, 'D_r shape is ' + str(D_r.shape)
    assert len(k_r.shape) == 1, 'k_r shape is ' + str(k_r.shape)
    
    net = []
    for n in range(self.num_kernels):
      net.append(ResBlock(D_r[n],
                          k_r,
                          channels))
    #Я помню эту ошибку из предыдущей домашки
    self.net = nn.ModuleList(net)
  
  def forward(self, x):
    #Надо прогнать сквозь блоки, сложить и разделить
    out = 0
    for block in self.net:
      out += block(x)
    return out / self.num_kernels


In [45]:
mrf = MRFBlock(3,
               D_r, # 3D
               k_r # 2D
               )

In [46]:
mrf_out = mrf(x)

In [47]:
mrf_out.shape

torch.Size([8, 3, 128])

In [21]:
r=torch.rand(1,3,3)

In [25]:
r * 2

tensor([[[0.2320, 1.9196, 1.0244],
         [0.8521, 1.2485, 1.0211],
         [0.8985, 0.2924, 1.8444]]])

In [24]:
t = 0
t += r
t+= r
t

tensor([[[0.2320, 1.9196, 1.0244],
         [0.8521, 1.2485, 1.0211],
         [0.8985, 0.2924, 1.8444]]])

In [74]:
class GeneratorModel(nn.Module):
  def __init__(self,
               h_u, #hidden dimension
               k_u, #kernel size for conv transposed
               D_r,
               k_r,
               slope=0.1
               ):
    super(GeneratorModel, self).__init__()

    #Это у нас первый блок. 
    #Захардкодила чиселки (
    self.first = nn.Conv1d(80, h_u, 7, 1, padding=3)  
    
    net = []
    in_channels = h_u

    for i in range(len(k_u)):
      out_channels = in_channels // 2

      net.append(nn.Sequential(*[
                                 nn.LeakyReLU(slope),
                                 #В мдз5 по глубинному обучению я использовала Upsample
                                 #Это было ошибкой
                                 nn.ConvTranspose1d(in_channels=in_channels,
                                                    out_channels=out_channels,
                                                    kernel_size=k_u[i],
                                                    stride = k_u[i] // 2,
                                                    padding = (k_u[i] - k_u[i] // 2) // 2
                                                    ),
                                 MRFBlock(out_channels, D_r, k_r)
                                 ])
                 )
      in_channels = out_channels
    
    self.net = nn.Sequential(*net)

    self.last = nn.Sequential(*[
                                nn.LeakyReLU(slope),
                                nn.Conv1d(in_channels, 1, 7, 1, padding=3),
                                nn.Tanh()
                                ])
    
  def forward(self, x):
    return self.last(self.net(self.first(x)))
    




In [97]:
n_fft =1024
num_mels = 80
hop_size = 256
win_size = 1024
sampling_rate = 22050
fmin = 0
fmax = 8000 



In [70]:
k_u = np.array([16, 16, 4, 4])
k_r = np.array([3, 7, 11])
D_r = np.array([[[1, 1], [3, 1], [5, 1]], [[1, 1], [3, 1], [5, 1]], [[1, 1], [3, 1], [5, 1]]]  )
g = GeneratorModel(
               64, #hidden dimension
               k_u, #kernel size for conv transposed
               D_r,
               k_r
               )

In [71]:
x = torch.rand(8,80,15)

In [72]:
ou = g(x)

In [73]:
ou.shape

torch.Size([8, 32, 120])

# Функции для тренировки 

## Логгирование аудио

Вроде можно как-то логгировать аудио прямо из тензора, но почему-то оно не работает.

Костыль: сохраню тензор в файлик, залоггирую файлик, удалю файлик.

In [82]:
import os
def add_audio_to_wandb(wav, transcript='', train=True):
    if train:
        key = 'train'
    else:
        key = 'val'

    path = 'audio_log_file.wav'
    with open(path, "wb") as f:
        f.write(wav.data)
    wandb.log(
          {key + '_audio': wandb.Audio(path, sample_rate=22050),
           key + '_transcript' : wandb.Html(transcript)})
    os.remove(path)

## Функция обучения

In [86]:
g = GeneratorModel(h_u=h_u, k_u=k_u,D_r=D_r,k_r=k_r)

In [113]:
from tqdm import tqdm
def train_g(model,
          loss_func,
          n_epochs,
          opt,
          train_loader,
          num_model=0,
          logging=False
          ):
    
    for epoch in range(n_epochs):
        train_loss  = train_epoch_g(model,
                    loss_func,
                    opt,
                    train_loader,
                    num_model,
                    logging)
        print('\nepoch ' + str(epoch) + '/' + str(n_epochs) + ' train loss = ' + str(train_loss))
        torch.save(model.state_dict(), 'model_' + str(num_model) +'.pt')
        torch.save(model.state_dict(), '/content/drive/MyDrive/hifi/model_' + str(num_model) +'.pt')
        if logging:
            wandb.log({
              'train_loss_epoch' : train_loss
            })
        
        
        
def train_epoch_g(model,
                loss_func,
                opt,
                train_loader,
                num_model=0,
                logging=False):
    model.train()
    model = model.to(DEVICE)

    torch.autograd.set_detect_anomaly(True)

    loss_sum = 0

    n_fft =1024
    num_mels = 80
    hop_size = 256
    win_size = 1024
    sampling_rate = 22050
    fmin = 0
    fmax = 8000 
    
    for mel, audio, filename, mel_loss in train_loader:

        mel = mel.to(DEVICE)
        audio = audio.to(DEVICE)

        opt.zero_grad()

        preds = model(mel)
        mel_preds = mel_spectrogram(preds.squeeze(1), 
                                    n_fft=n_fft,
                                    num_mels=num_mels,
                                    hop_size=hop_size,
                                    win_size=win_size,
                                    sampling_rate=sampling_rate,
                                    fmin = fmin,
                                    fmax=fmax).to(DEVICE)

        loss = loss_func(mel, mel_preds) * 45
        loss.backward()

        opt.step()

        
        loss_sum += loss.item()

        if logging:
            wandb.log({
              'train_loss_step' : loss.item(),
            })

        
    
    return loss_sum / len(train_loader)



     

In [90]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [110]:
#!g1.1
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
h_u = 128
D_r = np.array([[[1, 1], [3, 1], [5, 1]], [[1, 1], [3, 1], [5, 1]], [[1, 1], [3, 1], [5, 1]]]) 
k_r = np.array([3, 7, 11])
k_u = np.array([16, 16, 4, 4])  
 
g = GeneratorModel(h_u=h_u, k_u=k_u,D_r=D_r,k_r=k_r)
                
loss_func =  nn.L1Loss()
opt = torch.optim.Adam(g.parameters(), lr=0.0002, betas=(0.8, 0.99))

In [None]:
!pip install wandb

In [95]:
import wandb
wandb.login()
wandb.init(project="hifi L1 batch")

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mkogana00[0m (use `wandb login --relogin` to force relogin)


In [114]:
#У этой ячейки нет output, так как эта реализует весь процесс обучения, который быд на самом деле разбит
n_epochs = 2000
num_model = 0 #Это аргумент для сохранения модели, чтобы сохранялось в файл model_*num_model*.pt

train_g(model=g,
          loss_func=loss_func,
          n_epochs=n_epochs,
          opt=opt,
          train_loader=dataloader,
          num_model=num_model,
          logging=True
          )


epoch 0/2000 train loss = 65.42745971679688

epoch 1/2000 train loss = 72.67805480957031

epoch 2/2000 train loss = 75.03805541992188

epoch 3/2000 train loss = 77.2337417602539

epoch 4/2000 train loss = 71.04646301269531

epoch 5/2000 train loss = 76.57867431640625

epoch 6/2000 train loss = 61.33638000488281

epoch 7/2000 train loss = 64.98736572265625

epoch 8/2000 train loss = 74.51143646240234

epoch 9/2000 train loss = 60.715240478515625

epoch 10/2000 train loss = 81.49251556396484

epoch 11/2000 train loss = 70.57463836669922

epoch 12/2000 train loss = 64.55099487304688

epoch 13/2000 train loss = 59.597816467285156

epoch 14/2000 train loss = 61.22316360473633

epoch 15/2000 train loss = 67.6251220703125

epoch 16/2000 train loss = 56.71894836425781

epoch 17/2000 train loss = 66.066162109375

epoch 18/2000 train loss = 71.63614654541016

epoch 19/2000 train loss = 58.70157241821289

epoch 20/2000 train loss = 58.85361862182617

epoch 21/2000 train loss = 58.15373611450195


In [115]:
n_epochs = 2000
num_model = 1 #Это аргумент для сохранения модели, чтобы сохранялось в файл model_*num_model*.pt

train_g(model=g,
          loss_func=loss_func,
          n_epochs=n_epochs,
          opt=opt,
          train_loader=dataloader,
          num_model=num_model,
          logging=True
          )


epoch 0/2000 train loss = 23.833881378173828

epoch 1/2000 train loss = 21.68206214904785

epoch 2/2000 train loss = 26.35544776916504

epoch 3/2000 train loss = 23.75341033935547

epoch 4/2000 train loss = 24.74690055847168

epoch 5/2000 train loss = 22.23787498474121

epoch 6/2000 train loss = 25.34450340270996

epoch 7/2000 train loss = 22.29012107849121

epoch 8/2000 train loss = 24.003923416137695

epoch 9/2000 train loss = 24.059017181396484

epoch 10/2000 train loss = 24.29172134399414

epoch 11/2000 train loss = 22.668872833251953

epoch 12/2000 train loss = 23.247844696044922

epoch 13/2000 train loss = 22.36225700378418

epoch 14/2000 train loss = 23.64216423034668

epoch 15/2000 train loss = 23.08078956604004

epoch 16/2000 train loss = 24.15790367126465

epoch 17/2000 train loss = 22.625537872314453

epoch 18/2000 train loss = 27.000703811645508

epoch 19/2000 train loss = 22.47965431213379

epoch 20/2000 train loss = 26.561460494995117

epoch 21/2000 train loss = 22.76163

In [116]:
wandb.finish()

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
train_loss_epoch,█▅▄▄▄▃▃▃▃▄▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▁▁▁▁▁▁▁▁▁▁
train_loss_step,█▄▄▄▃▃▃▂▂▃▂▂▂▂▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
train_loss_epoch,19.2858
train_loss_step,19.2858
