In [46]:
import torch, torchvision
import numpy as np
import math
import os
import matplotlib.pyplot as plt
import pandas as pd
import torchaudio
from torch import nn
from torchvision import transforms
from torch.nn.utils.rnn import pad_sequence
from torch.cuda.amp import autocast
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from pathlib import Path
from tqdm.auto import tqdm
from torch import amp
from torch.amp import GradScaler
from torch.utils.data import DataLoader, SubsetRandomSampler, RandomSampler
from torch.optim.swa_utils import AveragedModel    
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn.functional import pad

import torchvision.utils as vutils
device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'

In [47]:
device

device(type='cuda')

In [48]:
config = {
    "dataset": {
        "train": {
            "table": "E:/data/train.csv",
            "data": "E:/data/bare_data/"
        },
        "val": {
            "table": "E:/data/val.csv",
            "data": "E:/data/bare_data/"
        }
    },
    "train": {
        "batch_size": 16,
        "grad_acum": 1,
        "dtype": "float32",
        'shuffle': True,
        'pin_memory': True,
    },
    "val": {
        "batch_size": 16,
        "grad_acum": 1,
        "dtype": "float32",
        'shuffle': False,
        'pin_memory': True,
    },
    "vae": {
        "freq": 16000,
        "lenght": 5,
    },
    "model": {
        "latent_size": 128,
        "epochs": 15,
        "learning_rate": 0.001,
        "freq_scale": 4,
        "time_scale": 4,
    },
    "utils": {
        "n_fft": 800, # TODO
    }
}

# VAE

In [49]:
class VAE_Audio(nn.Module):
    def __init__(self,):
        super().__init__()
        self.encoder_input = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, 1), nn.BatchNorm2d(32), nn.GELU(),
            nn.Conv2d(32, 16, 3, 1, 1), nn.BatchNorm2d(16), nn.GELU(),
        )
        self.encoder_squeeze = nn.Sequential(
            nn.Conv2d(16, 16, 3, 2, 1), nn.GELU(),
            nn.Conv2d(16, 16, 3, 2, 1), nn.GELU(),
        )
        self.encoder_mu     = nn.Conv2d(16, 32, 1)
        self.encoder_logvar = nn.Conv2d(16, 32, 1)
        self.decoder_unsqueeze = nn.Sequential(
                nn.ConvTranspose2d(32, 32, 3, 2, 1, output_padding=1), nn.GELU(),
                nn.ConvTranspose2d(32, 16, 3, 2, 1, output_padding=1), nn.GELU(),
        )
        self.decoder_output = nn.Sequential(
                 nn.ConvTranspose2d(16, 4, 3, 1, 1), nn.GELU(),
                 nn.BatchNorm2d(4),
                 nn.ConvTranspose2d(4, 1, 3, 1, 1),
        )
    def encode(self, x):
        x = self.encoder_input(x)
        x = self.encoder_squeeze(x)
        mu = self.encoder_mu(x)
        logvar = self.encoder_logvar(x)
        return mu, logvar
    def sample(self, x):
        mu, logvar = self.encode(x)
        std = torch.exp(0.5 * logvar)
        z = mu + std * torch.randn_like(std)
        return z, mu, logvar
    def decode(self, x):
        x = self.decoder_unsqueeze(x)
        x = self.decoder_output(x)
        return x
    def KLD_loss(self, mu, logvar, q=0.02):
        kld = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
        kld = torch.clamp(kld, min=q)
        return kld.mean()
    def forward(self, x):
        z, mu, logvar = self.sample(x)
        return self.decode(z), z, mu, logvar

# Моделька

In [50]:
class NoiseScheduler:
    def __init__(self, timestamps=100, epochs=100):
        self.steps = timestamps
        self.epochs = epochs
        self.betas = torch.linspace(0.02, 0.0004, timestamps)
        self.alpha = 1 - self.betas
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

    def __call__(self, step):
        return self.betas[step]
    def corrupt_image(self, image, idx):
        noise = torch.randn_like(image)
        b = image.shape[0]
        k = self.alpha_hat[idx].view(b, 1, 1, 1)
        return torch.sqrt(k)*image + torch.sqrt(1-k)*noise, noise

    def sample_timestamps(self, iters=10):
        indicies = torch.cat((torch.randint(0, self.steps-1, [iters-1]), torch.tensor(100)))
        return indicies

    def restore_image(self, image, pred, idx, sigma=0.0):
        noise = torch.randn_like(image)
        alpha = self.alpha[idx]
        alpha_hat = self.alpha_hat[idx]
        nalpha = 1-alpha
        nalpha_hat = 1-alpha_hat
        return (image - pred*nalpha/(torch.sqrt(nalpha_hat)))/torch.sqrt(alpha) + sigma*noise

    def get_idx(self, epoch, batch_size):
        return self.steps-torch.randint(0, self.steps, [batch_size]) - 1

In [51]:
class ResudialBlock(nn.Module):
    def __init__(self, *args):
        super().__init__()
        self.layers = nn.Sequential(*args)
    def forward(self, x):
        return x + self.layers(x)

class ConditionMixingLayer(nn.Module):
    def __init__(self, input_channels, conditioning_length, hidden_size = 8):
        super().__init__()
        self.hidden_size = hidden_size
        self.cond_proj = nn.Linear(conditioning_length, self.hidden_size)
        self.conv_proj = nn.Conv2d(input_channels, self.hidden_size, 3, 1, 1)
        self.lin_proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.lin1_unproj = nn.Linear(self.hidden_size, self.hidden_size)
        self.lin2_unproj = nn.Linear(self.hidden_size, self.hidden_size)
        self.conv_unproj = nn.Conv2d(self.hidden_size, input_channels, 3, 1, 1)
        self.conv_act = nn.Tanh()
        self.bn1 = nn.BatchNorm2d(input_channels)
        self.bn2 = nn.BatchNorm2d(input_channels)


        self.add_a = nn.Linear(self.hidden_size, self.hidden_size)
        self.add_b = nn.Linear(self.hidden_size, self.hidden_size)

    def forward(self, x, c=None, skip=False):
        if len(x.shape)==3:
            x = x.unsqueeze(0)
        b, ch, h, w = x.shape
        x = self.bn1(x)
        xn = self.conv_proj(x) # [B, N, H, W]
        xn = self.conv_act(xn)
        xn = xn.view(b, h*w, self.hidden_size) # [B, H*W, N]
        xn = self.lin_proj(xn) # [B, H*W, N]
        xn = self.conv_act(xn) # [B, H*W, N]
        if not skip:
            cn = self.cond_proj(c) # [B, N]
            cn = self.conv_act(cn) # [B, N]
            cn = cn.view(b, 1, self.hidden_size)
            xn = self.add_a(xn) # [B, 1, N]
            cn = self.add_b(cn) # [B, H*W, N]
            xn = xn + cn # [B, H*W, N]

        xn = self.lin1_unproj(xn) # [B, H*W, N]
        xn = self.conv_act(xn) # [B, H*W, N]
        xn = self.lin2_unproj(xn) # [B, H*W, N]
        xn = self.conv_act(xn) # [B, H*W, N]

        xn = xn.view(b, self.hidden_size, h, w) # [B, N, H, W]
        xn = self.conv_unproj(xn)
        xn = self.conv_act(xn) # [B, I, H, W]
        xn = self.bn2(xn)

        x = xn + x

        return x

In [52]:
class OurAttentionLayer(nn.Module):
    def __init__(self, patch_size, channels_in, hidden_dim, emb_size=1, cross=False):
        super().__init__()
        self.cross = cross
        self.hidden_dim = hidden_dim #N
        self.channels_in = channels_in #C
        self.patch_size = patch_size #pq
        self.emb_size = emb_size #Z
        self.Wk = nn.Linear(patch_size, hidden_dim)     # [pq, N]
        self.Wv = nn.Linear(patch_size, hidden_dim)     # [pq, N]
        self.LN = nn.LayerNorm([channels_in, patch_size])
        if cross:
            self.Wi = nn.Linear(emb_size, channels_in)  # [Z, C]
            self.Wj = nn.Linear(emb_size, hidden_dim)   # [Z, N]
            self.Wq = nn.Linear(hidden_dim, hidden_dim) # [N, N]
        else:
            self.Wq = nn.Linear(patch_size, hidden_dim) # [C, N]
        self.Wr = nn.Linear(hidden_dim, patch_size)
        self.softmax = nn.Softmax(dim=-1)
        self.dscale = 1/(hidden_dim**0.5)
    def forward(self, image, text = None, ret_attn_QKV=False):
        # image == [Batch, channels, patch_size] == [..., C, pq]
        K = self.Wk(image) # [..., C, pq] * [pq, N] = [..., C, N]
        V = self.Wv(image)
        if self.cross and text is None:
            text = torch.rand(1, self.emb_size)
        if self.cross:
            #text_T = torch.permute(text, (-1, -2))
            # text = [Batch, seq_len, emb_size] == [..., S, Z]
            I = self.Wi(text) # [..., S, Z] * [Z, C] -> [..., S, C]
            J = self.Wj(text) # [..., S, Z] * [Z, N] -> [..., S, N]
            Q1 = torch.einsum("...sc,...sn->...cn", I, J) # возможно надо отдебажить учитывая Batch и прочее
            # [..., C, S] * [..., S, N] -> C, N
            Q = self.Wq(Q1).unsqueeze(1).expand_as(K) # -> C, N
        else:
            Q = self.Wq(image) # [..., C, pq] * [pq, N] = [..., C, N]

        qk = torch.einsum("...jn,...cn->...cj", Q, K)
        R = self.softmax(qk*self.dscale)
        R = torch.einsum("...ic,...cn->...in", R, V) # Scaled Dot-Product Attention
        O = self.Wr(R) # [..., C, N] * [N, pq] -> [..., C, pq]
        O = O + image
        O = self.LN(O)
        if ret_attn_QKV:
            return O, Q, K, V
        return O

class PatchImage(nn.Module):
    def __init__(self, patch_size, reverse=False):
        super().__init__()
        self.patch_size = patch_size
        self.n = int(self.patch_size**(0.5))
        assert self.n**2 == patch_size, "patch_size must be full square"
        self.reverse = reverse
    def forward(self, x):

        n = self.n
        if self.reverse:
            b, c, h, w, s = x.shape
            x = torch.reshape(x, (b, c, h, w, n, n))
            x = torch.transpose(x, -2, -3)
            x = torch.reshape(x, (b, c, h*n, w*n))
            return x
        b, c, h, w = x.shape
        x = torch.reshape(x, (b, c, h//n, n, w//n, n))
        x = torch.transpose(x, -2, -3)
        x = torch.reshape(x, (b, c, h//n, w//n, n*n))
        return x
        #torch.reshape(torch.transpose(torch.reshape(a, (b, c, h//n, n, w//n, n)), -2, -3), (b, c, h//n, w//n, n*n))

In [53]:
class MNIST_diffusion(nn.Module):
    def __init__(self, input_channels=1, conditioning_length=1, timestamp_length=1, hidden_dims=32, mixin_dims=32):
        super().__init__()
        self.input_scaler = nn.Sequential(
                nn.Conv2d(in_channels=input_channels, out_channels=hidden_dims, kernel_size=1), nn.Tanh()
            )
        self.output_scaler = nn.Sequential(
                nn.Conv2d(in_channels=hidden_dims, out_channels=hidden_dims, kernel_size=1), nn.Tanh(),
                nn.Conv2d(in_channels=hidden_dims, out_channels=input_channels, kernel_size=1)
            )
        self.precode = nn.Sequential(
                nn.Conv2d(hidden_dims, hidden_dims, 3, 1, 1), nn.Tanh(), nn.BatchNorm2d(hidden_dims),
                nn.Conv2d(hidden_dims, hidden_dims, 3, 1, 1), nn.Tanh(), nn.BatchNorm2d(hidden_dims),
            )
        self.mixing_layer = ConditionMixingLayer(hidden_dims, conditioning_length+timestamp_length, mixin_dims)
        self.encoder = nn.Sequential(
            ResudialBlock(
                nn.Conv2d(hidden_dims, hidden_dims, 3, 1, 1), nn.Tanh(), nn.BatchNorm2d(hidden_dims),
                nn.Conv2d(hidden_dims, hidden_dims, 3, 1, 1), nn.Tanh(), nn.BatchNorm2d(hidden_dims),
                nn.Conv2d(hidden_dims, hidden_dims, 3, 1, 1), nn.Tanh(), nn.BatchNorm2d(hidden_dims),
            ),
            nn.Sequential(
                nn.Conv2d(hidden_dims, hidden_dims*4, 1), nn.Tanh(), nn.LazyInstanceNorm2d(),
                nn.PixelShuffle(2), #C/4
                nn.Conv2d(hidden_dims, hidden_dims, 1), nn.Tanh(), nn.LazyInstanceNorm2d(),
                nn.PixelUnshuffle(2), #C*4
                nn.Conv2d(hidden_dims*4, hidden_dims, 1), nn.Tanh(), nn.LazyInstanceNorm2d(),
            ),
            ResudialBlock(
                nn.Conv2d(hidden_dims, hidden_dims, 3, 1, 1), nn.Tanh(), nn.LazyInstanceNorm2d(),
                nn.Conv2d(hidden_dims, hidden_dims, 3, 1, 1), nn.Tanh(), nn.LazyInstanceNorm2d(),
            ),
        )
        self.downscaler = nn.Sequential(
            nn.Conv2d(hidden_dims, hidden_dims, 3, 2, 1), nn.BatchNorm2d(hidden_dims), nn.Tanh(),
            nn.Conv2d(hidden_dims, hidden_dims, 3, 1, 1), nn.Tanh(), nn.BatchNorm2d(hidden_dims),
            nn.Conv2d(hidden_dims, hidden_dims, 3, 1, 1), nn.Tanh(), nn.BatchNorm2d(hidden_dims),
        )

        self.upscaler = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims, hidden_dims, 3, 2, 1, output_padding=1), nn.BatchNorm2d(hidden_dims), nn.Tanh(),
            nn.Conv2d(hidden_dims, hidden_dims, 3, 1, 1), nn.Tanh(), nn.BatchNorm2d(hidden_dims),
            nn.Conv2d(hidden_dims, hidden_dims, 3, 1, 1), nn.Tanh(), nn.BatchNorm2d(hidden_dims),
        )
        self.decoder = nn.Sequential(
            ResudialBlock(
                nn.Conv2d(hidden_dims, hidden_dims, 3, 1, 1), nn.Tanh(), nn.BatchNorm2d(hidden_dims),
                nn.Conv2d(hidden_dims, hidden_dims, 3, 1, 1), nn.Tanh(), nn.BatchNorm2d(hidden_dims),
            ),
            nn.Sequential(
                nn.Conv2d(hidden_dims, hidden_dims*4, 1), nn.Tanh(), nn.LazyInstanceNorm2d(),
                nn.PixelShuffle(2), #C/4
                nn.Conv2d(hidden_dims, hidden_dims, 1), nn.Tanh(), nn.LazyInstanceNorm2d(),
                nn.PixelUnshuffle(2), #C*4
                nn.Conv2d(hidden_dims*4, hidden_dims, 1), nn.Tanh(), nn.LazyInstanceNorm2d(),
            ),
            ResudialBlock(
                nn.Conv2d(hidden_dims, hidden_dims, 3, 1, 1), nn.Tanh(), nn.BatchNorm2d(hidden_dims),
                nn.Conv2d(hidden_dims, hidden_dims, 3, 1, 1), nn.Tanh(), nn.BatchNorm2d(hidden_dims),
                nn.Conv2d(hidden_dims, hidden_dims, 3, 1, 1), nn.Tanh(), nn.BatchNorm2d(hidden_dims),
            ),
        )
        self.padder = nn.ZeroPad2d(1)
    def encode(self, x, mix = False, c = None):
        x = self.input_scaler(x)
        x = self.precode(x)
        if mix and not c is None:
            x = self.mixing_layer(x, c, False)
        else:
            x = self.mixing_layer(x, c, True)
        x = self.encoder(x)
        return x
    def rescale(self, x, n=0):
        if n<=0:
            return x
        b, c, ho, wo = x.shape
        if not (ho%2==0 and wo%2==0):
            y = self.padder(x)
        else:
            y = x
        y = self.downscaler(y)
        b, c, h, w = y.shape
        z = self.rescale(y, n-1)
        z = nn.functional.interpolate(z, [h, w])
        y = y + z
        y = self.upscaler(y)
        y = nn.functional.interpolate(y, [ho, wo])
        x = y + x
        return x
    def decode(self, x):
        x = self.decoder(x)
        x = self.output_scaler(x)
        return x
    def forward(self, x, mix=False, c = None, n=1):
        x = self.encode(x, mix, c)
        x = self.rescale(x, n)
        x = self.decode(x)
        return x


In [54]:
class Word_Encoder(nn.Module): # Токенайзер + Эмбеддер для букв
    def __init__(self, alphabet, emb_size, max_word_size = 256):
        super().__init__()
        self.alphabet = list(alphabet)+["<pad>", "<stress>", "<unk>"] # буквы + спец токены: пустой, ударение и неизвестный символ
        self.emb_size = emb_size
        self.embeddings = nn.Embedding(len(self.alphabet), emb_size)
        self.pos_embeddings = nn.Embedding(max_word_size, emb_size)
        self.device = self.embeddings.device

class Word_Encoder(nn.Module):
    def __init__(self, alphabet, emb_size, max_word_size=256):
        super().__init__()
        self.alphabet = list(alphabet) + ["<pad>", "<stress>", "<unk>"]
        self.emb_size = emb_size
        self.max_word_size = max_word_size
        self.embeddings = nn.Embedding(len(self.alphabet), emb_size)
        self.pos_embeddings = nn.Embedding(max_word_size, emb_size)

        self.get_idx = {char: idx for idx, char in enumerate(self.alphabet)}
        self.pad_idx = self.get_idx["<pad>"]
        self.stress_idx = self.get_idx["<stress>"]
        self.unk_idx = self.get_idx["<unk>"]
        self.device = self.embeddings.weight.device

    def tokenize(self, text):
        if isinstance(text, str):
            text = [text]
        tokenized = []
        for word in text:
            word_idxs = []
            i = 0
            n = len(word)
            while i < n:
                if word[i] == "<" and i + 8 < n and word[i:i+8] == "<stress>":
                  word_idxs.append(self.stress_idx)
                  i += 8
                else:
                    char = word[i]
                    if char in self.get_idx:
                        word_idxs.append(self.get_idx[char])
                    else:
                        word_idxs.append(self.unk_idx)
                    i += 1

            tokenized.append(word_idxs)
        max_len = max(len(word) for word in tokenized)
        padded = []
        for word in tokenized:
            padded_word = word
            if len(word) < max_len:
                padded_word += [self.pad_idx] * (max_len - len(word))
            padded.append(padded_word)

        return torch.tensor(padded, dtype=torch.long, device=self.device)

    def forward(self, x): # Не забыть проверить работу с батчами
        self.device = x.device
        batch, n = x.shape
        pos = torch.arange(n, device=self.device).unsquezze(0).expand(batch, n)
        x = self.embeddings(x) + self.pos_embeddings(pos)
        return x

class Noise_Encoder(nn.Module):
    def __init__(self, emb_size, timestamps = 1000):
        super().__init__()
        self.embeddings = nn.Embedding(timestamps, emb_size)
    def forward(self, x): # ✓ Не забыть проверить работу с батчами
        self.device = x.device
        return self.embeddings(x)
class Time_Encoder(nn.Module): # ✓ до 10 секунд
    def __init__(self, in_channels, out_channels, max_time_size=1024,): # Посмотреть максимальный размер по x, поставить на 20-50% больше
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, (3, 7), (1, 3), (1, 3))
        self.pos_embs =  nn.Embedding(max_time_size, out_channels)
        self.max_time_size = max_time_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv2 = nn.Conv2d(in_channels, out_channels, (3, 7), (1, 3), (1, 3))
        self.conv3 = nn.Conv2d(out_channels, out_channels, (3, 7), (1, 3), (1, 3))
        self.act = nn.Tanh()
    def forward(self, image):
        x = self.conv1(image)
        b, c, h, w = x.shape
        y = self.conv2(image)
        time = torch.arange(w).expand([b, w])
        pos = self.pos_embs(time) # [b, w, out]
        pos = torch.permute(pos, [0, 2, 1]).unsqueeze(1) # b, 1, c, w
        x = torch.permute(x, [0, 2, 1, 3]) # b, h, c, w
        pos = pos.expand_as(x)
        x = x + pos # [b, h, c, w]+[b, h, c, w]
        y = self.act(y)
        x = torch.permute(x, [0, 2, 1, 3])
        z = self.conv3(x)+y
        return z

In [134]:
class TTS_diffusion(nn.Module):
    def __init__(self, input_channels = 1, hidden_dims = 32, alphabet = "абвгдеёжзийклмнопрстуфхцчшщъыьэюя", emb_size_word = 128, emb_size_noise = 64,
                       noise_steps = 100, max_word_size = 256, max_time_size = 2048):
        super().__init__()
        self.input_scaler = nn.Sequential(
                nn.Conv2d(in_channels=input_channels, out_channels=hidden_dims, kernel_size=1), nn.Tanh()
            )
        self.precode = nn.Sequential(
                nn.Conv2d(hidden_dims, hidden_dims, 3, 1, 1), nn.Tanh(), nn.BatchNorm2d(hidden_dims),
                nn.Conv2d(hidden_dims, hidden_dims, 3, 2, 1), nn.BatchNorm2d(hidden_dims), nn.Tanh(),
                nn.Conv2d(hidden_dims, hidden_dims, 3, 2, 1), nn.BatchNorm2d(hidden_dims), nn.Tanh(),
                nn.ConvTranspose2d(hidden_dims, hidden_dims, 3, 2, 1, output_padding=1), nn.BatchNorm2d(hidden_dims), nn.Tanh(),
                nn.Conv2d(hidden_dims, hidden_dims, 3, 1, 1), nn.Tanh(), nn.BatchNorm2d(hidden_dims),
                nn.ConvTranspose2d(hidden_dims, hidden_dims, 3, 2, 1, output_padding=1), nn.BatchNorm2d(hidden_dims), nn.Tanh(),
                nn.Conv2d(hidden_dims, hidden_dims, 3, 1, 1), nn.Tanh(), nn.BatchNorm2d(hidden_dims),
            ) # ✓ что бы сделать размер кратным 4
        self.time_enc = Time_Encoder(hidden_dims, hidden_dims, max_time_size) # ✓ args сюда нужно пихнуть число каналов после precode + 
                                                                              # ✓ сколько хотим вернуть (лучше чуть больше) + максимальный размер картинки по X после MelSpec или как там оно
        self.post_time =  nn.Sequential(
                nn.Conv2d(hidden_dims, hidden_dims, 3, 1, 1), nn.Tanh(), nn.BatchNorm2d(hidden_dims),
                nn.Conv2d(hidden_dims, hidden_dims, 3, 1, 1), nn.Tanh(), nn.BatchNorm2d(hidden_dims),
        )
        self.word_enc = Word_Encoder(alphabet, emb_size_word, max_word_size) # ✓ args сюда нужно пихнуть алфавит, размер эмбеда (emb_size_word) + длину слова

        self.noise_enc = Noise_Encoder(emb_size_noise, noise_steps) # ✓ Пихнуть сюда число шагов в нойз шедулере и размер эмбеда (emb_size_noise)
        self.patch_img1_size = 16
        self.patch_img1 = PatchImage(self.patch_img1_size) # ✓ можно пробовать другие размеры, лучше больше 4
        self.atten_noise1 = OurAttentionLayer(self.patch_img1_size, hidden_dims, hidden_dims, emb_size_noise, True)
        self.atten_word1 = OurAttentionLayer(self.patch_img1_size, hidden_dims, hidden_dims, emb_size_word, True)
        self.atten_word2 = OurAttentionLayer(self.patch_img1_size, hidden_dims, hidden_dims, emb_size_word, True)
        self.satten1 = OurAttentionLayer(self.patch_img1_size, hidden_dims, hidden_dims, 1, False)
        self.unpatch_img1 = PatchImage(self.patch_img1_size, True)
        self.main_block1 = nn.Sequential(
                nn.Conv2d(hidden_dims, hidden_dims, 3, 1, 1), nn.Tanh(), nn.BatchNorm2d(hidden_dims),
                nn.Conv2d(hidden_dims, hidden_dims, 3, 2, 1), nn.BatchNorm2d(hidden_dims), nn.Tanh(),
                nn.Conv2d(hidden_dims, hidden_dims, 3, 2, 1), nn.BatchNorm2d(hidden_dims), nn.Tanh(),
            ) # ✓ сжимает в 4 раза
        
        # ✓ перед применением при помощи interpolate сделать картинку кратной sqrt(patch_size) по H и W (или использовать shuffle layers + конвы)
        self.patch_img2_size = 16
        self.patch_img2 = PatchImage(self.patch_img2_size) # ✓ можно пробовать другие размеры
        self.atten_word3 = OurAttentionLayer(self.patch_img2_size, hidden_dims, hidden_dims, emb_size_word, True)
        self.satten2 = OurAttentionLayer(self.patch_img2_size, hidden_dims, hidden_dims, 1, False)
        self.atten_word4 = OurAttentionLayer(self.patch_img2_size, hidden_dims, hidden_dims, emb_size_word, True)
        self.unpatch_img2 = PatchImage(self.patch_img2_size, True)
        self.main_block2 = nn.Sequential(
                nn.ConvTranspose2d(hidden_dims, hidden_dims, 3, 2, 1, output_padding=1), nn.BatchNorm2d(hidden_dims), nn.Tanh(),
                nn.Conv2d(hidden_dims, hidden_dims, 3, 1, 1), nn.Tanh(), nn.BatchNorm2d(hidden_dims),
                nn.ConvTranspose2d(hidden_dims, hidden_dims, 3, 2, 1, output_padding=1), nn.BatchNorm2d(hidden_dims), nn.Tanh(),
                nn.Conv2d(hidden_dims, hidden_dims, 3, 1, 1), nn.Tanh(), nn.BatchNorm2d(hidden_dims),
        )
        self.patch_img3_size = 16
        self.patch_img3 = PatchImage(self.patch_img3_size) # ✓ можно пробовать другие размеры
        self.atten_noise2 = OurAttentionLayer(self.patch_img3_size, hidden_dims, hidden_dims, emb_size_noise, True)
        self.atten_word5 = OurAttentionLayer(self.patch_img3_size, hidden_dims, hidden_dims, emb_size_word, True)
        self.atten_word6 = OurAttentionLayer(self.patch_img3_size, hidden_dims, hidden_dims, emb_size_word, True)
        self.satten3 = OurAttentionLayer(self.patch_img3_size, hidden_dims, hidden_dims, 1, False)
        self.unpatch_img3 = PatchImage(self.patch_img3_size, True)
        self.main_block3 = nn.Sequential(
                nn.Conv2d(hidden_dims, hidden_dims, 3, 1, 1), nn.Tanh(), nn.BatchNorm2d(hidden_dims),
                nn.Conv2d(hidden_dims, hidden_dims, 3, 1, 1), nn.Tanh(), nn.BatchNorm2d(hidden_dims),
                nn.Conv2d(hidden_dims, hidden_dims, 3, 1, 1), nn.Tanh(), nn.BatchNorm2d(hidden_dims),
        )
        self.output_scaler = nn.Sequential(
                nn.Conv2d(in_channels=hidden_dims, out_channels=input_channels, kernel_size=1), nn.Tanh(),
        )
        self.out_scaler_conv1 = nn.Conv2d(in_channels=hidden_dims, out_channels=input_channels, kernel_size=1)
        self.out_scaler_conv2 = nn.Conv2d(in_channels=input_channels, out_channels=input_channels, kernel_size=1)
    def drop(self, layer, x, *args):
        if self.training and torch.rand(1).item() < 0.1:
            return x
        return layer(x, *args)
    def forward(self, x, text, noise): # ✓ в трейне иногда дропать каждый (делать torch.rand) с вероятностью 1/10
        x = self.input_scaler(x)
        words = self.word_enc(text)
        sh = self.noise_enc(noise)
        x = self.time_enc(x)
        x = self.post_time(x)
        x = self.precode(x)
        
        x = self.resize_to_square(x, self.patch_img1_size)
        x = self.patch_img1(x)
        x = self.drop(self.atten_noise1, x, sh) # ✓ p = 0.1
        x = self.atten_word1(x, words)
        x = self.atten_word2(x, words)
        x = self.satten1(x)
        x = self.unpatch_img1(x)
        x = self.main_block1(x)

        x = self.resize_to_square(x, self.patch_img2_size)
        x = self.patch_img2(x)
        x = self.atten_word3(x, words)
        x = self.atten_word4(x, words)
        x = self.satten2(x)
        x = self.unpatch_img2(x)
        x = self.main_block2(x)
        # ✓ как-то напихать все что есть
        x = self.resize_to_square(x, self.patch_img3_size)
        x = self.patch_img3(x)
        x = self.drop(self.atten_noise2, x, sh) # ✓ p = 0.1
        x = self.atten_word5(x, words)
        x = self.atten_word6(x, words)
        x = self.satten3(x)
        x = self.unpatch_img3(x)
        x = self.main_block3(x)
        
        y = self.out_scaler_conv1(x)
        x = self.output_scaler(x)
        x = x * y
        x = self.out_scaler_conv2(x)
        return x
    def resize_to_square(self, x, patch_size):
        n = int(math.sqrt(patch_size))
        h1, w1 = x.shape[-2:]
        h2, w2 = ((h1 + n - 1) // n) * n, ((w1 + n - 1) // n) * n
        if h2 != h1 or w2 != w1:
            x = torch.nn.functional.interpolate(x, size=(h2, w2), mode="bilinear", align_corners=False)
        return x

# Тренер

In [180]:
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

class AvegereMeter:
    def __init__(self,):
        self.arr = []
    def __call__(self, item, n=1):
        if n<=1:
            self.arr.extend([item])
        else:
            self.arr.extend([item]*n)
    def __str__(self,) -> str:
        return str(np.mean(np.array(self.arr)))
    def zero(self,):
        self.arr=[]

class TTS_Trainer:
    def __init__(self, model, train_dataloader, val_dataloader, vae, epochs=10, ):
        self.model = model
        self.vae = vae
        self.device = next(model.parameters()).device
        self.tdl = train_dataloader
        self.vdl = val_dataloader
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=5e-4, weight_decay=1e-2)
        self.scaler = GradScaler(device="cuda")
        self.l2_loss = nn.MSELoss(reduction='none')
        self.loss_meter = AvegereMeter()
        self.epochs = epochs
        self.losses = []
        self.noise_sched = NoiseScheduler(100, epochs)
        self.noise_sched.alpha = self.noise_sched.alpha.to(self.device)
        self.noise_sched.alpha_hat = self.noise_sched.alpha_hat.to(self.device)

    def draw_diffusion(self, S=5, epoch=0):
        fig, axes = plt.subplots(11, S+1, figsize=(S+1, 11))
        noise = torch.randn([1, 32, 7, 7], device=self.device)
        transformed_weight = [noise.clone() for _ in range(11)]
        for i in range(0, 11):
            for j in range(S+1):

                if j>0:
                    #noise[0][0].numpy()+
                    idx = int((j/S)*99)
                    k = self.noise_sched(idx).item()
                    pred = self.model(transformed_weight[i], i<10, torch.tensor([*[0]*i,1,*[0]*(9-i), idx], dtype=torch.float32, device=self.device))
                    transformed_weight[i] = self.noise_sched.restore_image(transformed_weight[i], pred, idx, 1e-5)

                decoded = self.vae.decode(transformed_weight[i])
                img = decoded.squeeze(0).squeeze(0).cpu().detach().numpy()
                axes[i, j].imshow(img)
                axes[i, j].axis('off')
                if i==0:
                    axes[i, j].set_title(f'{j/S}')
        os.makedirs('./train/dl1/', exist_ok=True)
        plt.savefig(f'./train/dl1/diff{S}_{epoch}.png')
        plt.close(fig)

    def train_loop(self, KLDk=0.05, epoch=0):
        self.model.train()
        self.loss_meter.zero()
        pbar = tqdm(self.tdl, desc = 'train')
        cnt = 0
        self.vae.train()
        for mel, _ in pbar:
            mel = mel.to(self.device).float()
            b, c, h, w = mel.shape
            _, mu, logvar = self.vae.sample(mel)   # mu, logvar — [B,16,H',W']
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            z = mu + eps * std 
            #mel = z
            with torch.autocast('cuda', dtype=torch.bfloat16):
                idx = self.noise_sched.get_idx(epoch, b).to(self.device)
                model_input, noise = self.noise_sched.corrupt_image(z, idx)
                model_input = model_input.to(self.device)
                noise = noise.to(self.device)
                text_ids = torch.zeros(b, 1, dtype=torch.long, device=self.device)
                cond_onehot = nn.functional.one_hot(cond, num_classes=10).float().to(self.device)
                condition = torch.cat([cond_onehot, idx.view(-1,1)], dim=1)

                output = self.model(model_input, text_ids, idx)
        
                noise_loss = self.l2_loss(output, noise).mean()
                kld_loss = self.vae.KLD_loss(mu, logvar)
                loss = noise_loss + KLDk * kld_loss
        
            self.optimizer.zero_grad()
            self.losses.append(loss.item())
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()
            cur_loss = loss.detach().cpu().item()
            pbar.desc = f"train, loss={cur_loss:.3f}"
            self.loss_meter(cur_loss, mel.size(0))
        print("Loss = "+self.loss_meter.__str__())

    def save_loss(self, filepath):
        with open(filepath, "w") as f:
            f.write(str(self.loss_meter))

    def save_image(self, audio, output, iter = 0):
        input_tensor = audio.cpu().detach().squeeze(0)   # shape [1, H, W]
        output_tensor = output.cpu().detach().squeeze(0)  # shape [1, H, W]

        # Преобразование в PIL-изображение
        transform = transforms.ToPILImage('RGB')
        rgb_input = torch.cat([input_tensor,  input_tensor,  input_tensor ], dim=0)
        rgb_output = torch.cat([output_tensor, output_tensor, output_tensor], dim=0)
        input_image = transform(rgb_input)
        output_image = transform(rgb_output)

        os.makedirs('./train', exist_ok=True)
        # Сохранение изображений
        input_image.save(f'./train/input_{iter}.png')
        output_image.save(f'./train/output_{iter}.png')

    def val_loop(self, iteration=0):
        self.model.eval()
        batch = next(iter(self.vdl))
        if isinstance(batch, (list, tuple)):
            mel = batch[0]                  
        else:
            mel = batch
        mel = mel.to(self.device, dtype=torch.float32)
        with torch.no_grad():
            z0 = self.vae.sample(mel[:1])[0]            # [1,C,H,W]
        cond_onehot = torch.zeros(1, 10, device=self.device)  # [1,10]
        t_start = torch.tensor([[float(self.noise_sched.steps-1)]], device=self.device)
        condition  = torch.cat([cond_onehot, t_start], dim=1)  # [1,11]
        v = torch.randn_like(z0)        # x_T
        T = self.noise_sched.steps
        snap_ts = {T-1, int(0.7*T), int(0.4*T), 0}
        snapshots = []
        for t in range(T-1, -1, -1):
            eps_pred = self.model(v, cond_onehot.argmax(dim=1), t)
            v = self.noise_sched.restore_image(v, eps_pred, t, sigma=0.0)
            if t in snap_ts:
                a_hat_t = self.noise_sched.alpha_hat[t]
                x0_pred = (v - torch.sqrt(1-a_hat_t)*eps_pred) / torch.sqrt(a_hat_t)
                img = self.vae.decode(x0_pred.float()).detach()
                snapshots.append((t, img.squeeze().cpu().numpy()))
        cols = len(snapshots)
        fig, axes = plt.subplots(1, cols, figsize=(cols*2, 2), dpi=150)
        for i, (tt, im) in enumerate(snapshots):
            axes[i].imshow(im)
            axes[i].set_title(f"t={tt}", fontsize=8)
            axes[i].axis("off")
        plt.tight_layout()
        plt.show()
        plt.close(fig)

In [181]:
mel_spec = torchaudio.transforms.MelSpectrogram(
    sample_rate = 16000,
    n_fft = 800,
    hop_length = 200,     
    win_length  = 800,
    n_mels = 80,
)

def wav_to_mel(wav):
    # wav: (1, L)
    spec = mel_spec(wav)      
    spec = torch.log(spec + 1e-6)
    T = spec.shape[-1]
    if T < 80:                    
        spec = F.pad(spec, (0, 80 - T))
    elif T > 80:                  
        spec = spec[..., :80]
    return spec  

class AudioDataset(Dataset):
    def __init__(self, csv_file, audio_dir, target_sr=16000, length_sec=None, transform=None):
        self.table = pd.read_csv(csv_file)
        self.audio_dir = Path(audio_dir)
        self.sr = target_sr
        self.length = int(target_sr * length_sec) if length_sec else None
        self.transform = transform

    def load_wav(self, path):
        wav, sr = torchaudio.load(path)            
        if wav.shape[0] > 1:        
            wav = wav.mean(dim=0, keepdim=True)
        if sr != self.sr:           
            wav = torchaudio.functional.resample(wav, sr, self.sr)
        return wav

    def pad_trim(self, wav):
        if self.length is None:
            return wav
        cur = wav.shape[-1]
        if cur > self.length:
            wav = wav[..., : self.length]
        elif cur < self.length:
            wav = pad(wav, (0, self.length - cur))
        return wav

    def __getitem__(self, idx):
        rel = self.table.iloc[idx]["file"]
        wav = self.load_wav(self.audio_dir / rel)
        wav = self.pad_trim(wav).float()
        mel = wav_to_mel(wav) 

        if self.transform:
            mel = self.transform(mel)

        return mel, 0         

    def __len__(self):
        return len(self.table)


def audio_collate(batch):
    mels, labels = zip(*batch)           
    return torch.stack(mels), torch.tensor(labels)                   


from torch.utils.data import DataLoader, SubsetRandomSampler

def build_dataloader(cfg, split, transform=None, workers=4, limit=10_000):
    d  = cfg["dataset"][split]
    ds = AudioDataset(
        d["table"], d["data"],
        cfg["vae"]["freq"], cfg["vae"]["lenght"],
        transform,
    )
    if limit and limit < len(ds):
        idx = np.random.choice(len(ds), limit, replace=False)
        ds = torch.utils.data.Subset(ds, idx)    
    sampler = None

    return DataLoader(
        ds,
        batch_size = cfg[split]["batch_size"],
        shuffle = (split == "train"),
        sampler = sampler,
        num_workers = workers,
        pin_memory = cfg[split]["pin_memory"],
        collate_fn = audio_collate,
    )


In [182]:
import torch.backends.cudnn as cudnn
cudnn.benchmark = True
torch.set_float32_matmul_precision("high")

vae = VAE_Audio().to(device)
params_vae = torch.load(str(Path.home() / "Downloads" / "VaeAudio_V4_noise.pt"), map_location=device, weights_only=True)
vae.load_state_dict(params_vae)
vae.eval()


unet = MNIST_diffusion(input_channels=32, conditioning_length=10, timestamp_length=1, hidden_dims=128, mixin_dims=128).to(device)
with torch.no_grad():
    dummy = torch.zeros(1, 32, 8, 8, device=device)
    _ = unet(dummy, mix=False, c=None, n=1)

state = torch.load(Path.home() / "Downloads" / "Unet_diffusion_v6.pt", map_location=device, weights_only=True)
unet.load_state_dict(state)
unet.eval()

tts_model = TTS_diffusion(
    input_channels=32,
    hidden_dims=128,
    alphabet=alphabet,
    emb_size_word=emb_size_word,
    emb_size_noise=emb_size_noise,
    noise_steps=noise_steps,
    max_word_size=max_word_size,
    max_time_size=2048,
).to(device)

train_csv = config['dataset']['train']['table']
train_wav_dir = config['dataset']['train']['data']
val_csv = config['dataset']['val']['table']
val_wav_dir = config['dataset']['val']['data']

train_dataloader = build_dataloader(config, "train", workers=0, limit=2500)
val_dataloader = build_dataloader(config, "val", workers=0, limit=None)

all_epochs = 10

with torch.no_grad():
    mel_zero = torch.zeros(1, 1, 80, 80, device=device)
    z0, mu0, logvar0 = vae.sample(mel_zero)
    C_lat, H_lat, W_lat = z0.shape[1:]
    dummy_z = torch.zeros(1, C_lat, H_lat, W_lat, device=device)
    fake_text_ids = torch.zeros(1, 1, dtype=torch.long, device=device) 
    fake_text_lens = torch.ones(1, dtype=torch.long, device=device)    
    fake_t = torch.zeros(1, dtype=torch.long, device=device)
    _ = unet(dummy_z, fake_text_ids, fake_text_lens, fake_t)

In [183]:
trainer = TTS_Trainer(
    model=tts_model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    vae=vae,
    epochs=all_epochs
)
for epoch in range(all_epochs):
    print(f"\n Epoch: {epoch + 1}/{all_epochs}")
    trainer.train_loop()
    trainer.val_loop(iteration=epoch)

torch.save(tts_model.state_dict(), "TTSModelParameters_v6.pt")


 Epoch: 1/10


train:   0%|          | 0/157 [00:00<?, ?it/s]

NameError: name 'text_ids' is not defined

In [172]:
import torch
from IPython.display import Audio
import torchaudio
from pathlib import Path

# 1) Создаём энкодеры и расписатель шума
#   Используйте те же гиперпараметры, что и при обучении!
alphabet      = "абвгдеёжзийклмнопрстуфхцчшщъыьэюя"
emb_size_word = 128
max_word_size = 256
word_enc      = Word_Encoder(alphabet, emb_size_word, max_word_size).to(device)

emb_size_noise= 64
noise_steps   = 1000
noise_enc     = Noise_Encoder(emb_size_noise, noise_steps).to(device)

noise_sched   = NoiseScheduler(noise_steps, all_epochs)
noise_sched.alpha     = noise_sched.alpha.to(device)
noise_sched.alpha_hat = noise_sched.alpha_hat.to(device)

# 2) Вспомогательная функция для токенизации текста
def text_to_tensor(text: str):
    # метод tokenize уже возвращает тензор [1, L]
    return word_enc.tokenize(text).to(device)

# 3) Обратная диффузия (немного адаптированная под ваши классы)
@torch.no_grad()
def sample_mel(model, text_ids):
    # text_ids: [1, L]
    # 3.1 получаем эмбед текст+позиции
    text_emb = word_enc(text_ids)             # [1, L, emb_size]

    # 3.2 стартовый шум
    # C_lat = скрытые каналы после VAE decoder, H_lat/W_lat — spatial
    dummy = torch.zeros(1, 1, 80, 80, device=device)
    C_lat, H_lat, W_lat = vae.sample(dummy)[0].shape[1:]
    v = torch.randn(1, C_lat, H_lat, W_lat, device=device)

    # 3.3 обратный проход
    for t in range(noise_steps-1, -1, -1):
        # формируем conditioning: [text_emb_flat | t]
        # ваш ConditionMixingLayer принимает сначала x, потом c
        t_tensor = torch.tensor([[t]], device=device)
        cond = torch.cat([text_emb.mean(dim=1), t_tensor.float()], dim=1)  # [1, emb_size+1]

        eps_pred = model(v, mix=True, c=cond, n=1)
        v = noise_sched.restore_image(v, eps_pred, t, sigma=0.0)

    # 3.4 декодим через VAE
    mel = vae.decode(v)      # [1,1,80,80]
    return torch.clamp(mel, min=0)

# 4) Mel → waveform через Griffin-Lim
def mel_to_waveform(mel):
    # mel: [1,1,80,80]
    mel = mel.squeeze(0)  # [1,80,80]
    inv_mel = torchaudio.functional.inverse_mel_scale(
        mel,
        sample_rate=16000,
        n_stft=800//2+1,
        f_min=0,
        f_max=8000
    )
    wav = torchaudio.functional.griffinlim(
        inv_mel,
        n_fft=800,
        hop_length=200,
        win_length=800,
        n_iter=60
    )
    return wav

# ──────────────────────────────────────────────────────────────────────────────
# 5) Пример запуска инференса
model = TTS_diffusion(
    input_channels   = 32,
    hidden_dims      = 128,
    alphabet         = alphabet,        # как при обучении
    emb_size_word    = emb_size_word,   # 128
    emb_size_noise   = emb_size_noise,  # 64
    noise_steps      = noise_steps,     # 1000
    max_word_size    = max_word_size,   # 256
    max_time_size    = 2048,
).to(device)

ckpt = torch.load("TTSModelParameters_v6.pt", map_location="cuda", weights_only=True)
model.load_state_dict(ckpt, strict=False)
model.eval()
text      = "привет, как дела?"
text_ids  = text_to_tensor(text)         # [1, L]
mel_pred  = sample_mel(model, text_ids)  # [1,1,80,80]
wav_pred  = mel_to_waveform(mel_pred)    # [1, L_wave]

# 6) Слушаем результат в ноутбуке
Audio(wav_pred.cpu().squeeze(0).numpy(), rate=16000)


RuntimeError: Error(s) in loading state_dict for TTS_diffusion:
	size mismatch for input_scaler.0.weight: copying a param with shape torch.Size([128, 32, 1, 1]) from checkpoint, the shape in current model is torch.Size([32, 128, 1, 1]).
	size mismatch for input_scaler.0.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for precode.0.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([32, 32, 3, 3]).
	size mismatch for precode.0.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for precode.2.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for precode.2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for precode.2.running_mean: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for precode.2.running_var: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for precode.3.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([32, 32, 3, 3]).
	size mismatch for precode.3.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for output_scaler.0.weight: copying a param with shape torch.Size([128, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 32, 1, 1]).

In [170]:
ckpt = torch.load("TTSModelParameters_v6.pt", map_location="cpu")
print(list(ckpt.keys())[:20])


['input_scaler.0.weight', 'input_scaler.0.bias', 'output_scaler.0.weight', 'output_scaler.0.bias', 'output_scaler.2.weight', 'output_scaler.2.bias', 'precode.0.weight', 'precode.0.bias', 'precode.2.weight', 'precode.2.bias', 'precode.2.running_mean', 'precode.2.running_var', 'precode.2.num_batches_tracked', 'precode.3.weight', 'precode.3.bias', 'precode.5.weight', 'precode.5.bias', 'precode.5.running_mean', 'precode.5.running_var', 'precode.5.num_batches_tracked']


  ckpt = torch.load("TTSModelParameters_v6.pt", map_location="cpu")
