In [5]:
# %pip install einops torchsummary

Collecting torchsummary
  Downloading torchsummary-1.5.1-py3-none-any.whl (2.8 kB)
Installing collected packages: torchsummary
Successfully installed torchsummary-1.5.1


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
# import math
import numpy as np

# from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
# from torchsummary import summary


class Generator(nn.Module):
    """Transformer generator: latent noise vector -> multichannel sequence.

    The latent vector is projected to ``seq_len`` tokens of width
    ``embed_dim``, a learned positional embedding is added, the tokens go
    through a stack of transformer encoder blocks, and a 1x1 convolution maps
    ``embed_dim`` features to ``channels`` output features.

    Args:
        seq_len: number of time steps in the generated sequence.
        patch_size: stored on the instance but not used by this class.
        channels: number of output feature channels.
        latent_dim: size of the input noise vector.
        embed_dim: token width used inside the transformer.
        depth: number of stacked encoder blocks.
        num_heads: attention heads per encoder block.
        forward_drop_rate: dropout rate passed as ``forward_drop_p``.
        attn_drop_rate: dropout rate passed as ``drop_p``.
    """

    def __init__(self, seq_len=150, patch_size=15,
                 channels=3,
                 latent_dim=100, embed_dim=10,
                 depth=3, num_heads=5,
                 forward_drop_rate=0.5, attn_drop_rate=0.5):
        super().__init__()
        # Hyperparameters are kept as attributes so they can be inspected later.
        self.channels = channels
        self.latent_dim = latent_dim
        self.seq_len = seq_len
        self.embed_dim = embed_dim
        self.patch_size = patch_size
        self.depth = depth
        self.num_heads = num_heads
        self.attn_drop_rate = attn_drop_rate
        self.forward_drop_rate = forward_drop_rate

        # Noise -> flat vector that forward() reshapes to (seq_len, embed_dim).
        self.l1 = nn.Linear(latent_dim, seq_len * embed_dim)
        # Learned positional embedding, initialised to zeros.
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        self.blocks = Gen_TransformerEncoder(
            depth=depth,
            emb_size=embed_dim,
            num_heads=num_heads,
            drop_p=attn_drop_rate,
            forward_drop_p=forward_drop_rate,
        )
        # 1x1 convolution: embed_dim feature maps -> channels feature maps.
        self.deconv = nn.Sequential(
            nn.Conv2d(embed_dim, channels, kernel_size=1, stride=1, padding=0)
        )

    def forward(self, z):
        """Generate sequences from noise.

        Args:
            z: noise tensor whose last dimension is ``latent_dim``.

        Returns:
            Tensor of shape ``(batch, channels, 1, seq_len)``.
        """
        tokens = self.l1(z).view(-1, self.seq_len, self.embed_dim)
        tokens = self.blocks(tokens + self.pos_embed)
        # (batch, seq_len, embed_dim) -> (batch, embed_dim, 1, seq_len) for Conv2d.
        feature_map = tokens.transpose(1, 2).unsqueeze(-2)
        return self.deconv(feature_map)
    
class Gen_TransformerEncoderBlock(nn.Sequential):
    """One generator encoder block: two residual branches applied in order.

    Branch 1: LayerNorm -> MultiHeadAttention -> Dropout(drop_p).
    Branch 2: LayerNorm -> FeedForwardBlock -> Dropout(drop_p).

    ``forward_drop_p`` is only used inside the feed-forward block; the dropout
    that closes each branch uses ``drop_p``.
    """

    def __init__(self, emb_size, num_heads=5, drop_p=0.5, forward_expansion=4, forward_drop_p=0.5):
        attention_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            MultiHeadAttention(emb_size, num_heads, drop_p),
            nn.Dropout(drop_p),
        )
        feed_forward_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            FeedForwardBlock(emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
            nn.Dropout(drop_p),
        )
        super().__init__(
            ResidualAdd(attention_branch),
            ResidualAdd(feed_forward_branch),
        )

class Gen_TransformerEncoder(nn.Sequential):
    """Stack of ``depth`` generator encoder blocks applied one after another.

    All keyword arguments are forwarded unchanged to every
    ``Gen_TransformerEncoderBlock``.
    """

    def __init__(self, depth=8, **kwargs):
        layers = []
        for _ in range(depth):
            layers.append(Gen_TransformerEncoderBlock(**kwargs))
        super().__init__(*layers)
             
class MultiHeadAttention(nn.Module):
    """Multi-head dot-product self-attention implemented in plain PyTorch.

    Args:
        emb_size: token width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.

    Note:
        Scores are divided by ``sqrt(emb_size)`` (the full token width, not
        the per-head width). This is kept as in the original implementation
        so existing weights keep producing the same outputs.
    """

    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def _split_heads(self, t: Tensor) -> Tensor:
        """Reshape ``(b, n, h*d)`` to ``(b, h, n, d)``.

        Equivalent to the einops pattern ``"b n (h d) -> b h n d"``.
        """
        b, n, _ = t.shape
        return t.reshape(b, n, self.num_heads, -1).transpose(1, 2)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        """Apply self-attention to ``x``.

        Args:
            x: tensor of shape ``(batch, tokens, emb_size)``.
            mask: optional boolean tensor broadcastable to
                ``(batch, num_heads, tokens, tokens)``; ``True`` marks
                query/key pairs that may attend, ``False`` pairs are
                suppressed.

        Returns:
            Tensor of shape ``(batch, tokens, emb_size)``.
        """
        queries = self._split_heads(self.queries(x))
        keys = self._split_heads(self.keys(x))
        values = self._split_heads(self.values(x))

        # (b, h, q, d) @ (b, h, d, k) -> (b, h, q, k)
        energy = torch.matmul(queries, keys.transpose(-1, -2))
        if mask is not None:
            # masked_fill is out-of-place, so the result must be re-assigned.
            # The fill value follows the score dtype so half precision works.
            fill_value = torch.finfo(energy.dtype).min
            energy = energy.masked_fill(~mask, fill_value)

        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)

        # (b, h, q, k) @ (b, h, k, d) -> (b, h, q, d)
        out = torch.matmul(att, values)
        b, _, n, _ = out.shape
        # Merge heads back: (b, h, n, d) -> (b, n, h*d).
        out = out.transpose(1, 2).reshape(b, n, -1)
        return self.projection(out)
    
class ResidualAdd(nn.Module):
    """Residual wrapper used inside the transformer blocks: ``fn(x) + x``.

    Args:
        fn: callable/module applied to the input; its output must be
            broadcast-compatible with the input.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(x, **kwargs) + x`` without modifying any tensor in place.

        The sum is computed out-of-place: an in-place ``+=`` on the output of
        ``fn`` would overwrite the caller's input when ``fn`` returns its
        argument, and breaks autograd when ``fn`` saved its output for the
        backward pass.
        """
        return self.fn(x, **kwargs) + x
     
class FeedForwardBlock(nn.Sequential):
    """Position-wise feed-forward network used inside a transformer block.

    Layout: Linear(emb_size -> expansion * emb_size) -> GELU -> Dropout ->
    Linear(expansion * emb_size -> emb_size).

    Args:
        emb_size: input and output token width.
        expansion: multiplier for the hidden width.
        drop_p: dropout probability applied after the activation.
    """

    def __init__(self, emb_size, expansion, drop_p):
        hidden_size = expansion * emb_size
        widen = nn.Linear(emb_size, hidden_size)
        activation = nn.GELU()
        regularize = nn.Dropout(drop_p)
        narrow = nn.Linear(hidden_size, emb_size)
        super().__init__(widen, activation, regularize, narrow)

        
class Dis_TransformerEncoderBlock(nn.Sequential):
    """One discriminator encoder block: two residual branches applied in order.

    Branch 1: LayerNorm -> MultiHeadAttention -> Dropout(drop_p).
    Branch 2: LayerNorm -> FeedForwardBlock -> Dropout(drop_p).

    ``forward_drop_p`` is only used inside the feed-forward block; the dropout
    that closes each branch uses ``drop_p``.
    """

    def __init__(self,
                 emb_size=100,
                 num_heads=5,
                 drop_p=0.,
                 forward_expansion=4,
                 forward_drop_p=0.):
        attention_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            MultiHeadAttention(emb_size, num_heads, drop_p),
            nn.Dropout(drop_p),
        )
        feed_forward_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            FeedForwardBlock(emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
            nn.Dropout(drop_p),
        )
        super().__init__(
            ResidualAdd(attention_branch),
            ResidualAdd(feed_forward_branch),
        )


class Dis_TransformerEncoder(nn.Sequential):
    """Stack of ``depth`` discriminator encoder blocks applied in sequence.

    All keyword arguments are forwarded unchanged to every
    ``Dis_TransformerEncoderBlock``.
    """

    def __init__(self, depth=8, **kwargs):
        layers = []
        for _ in range(depth):
            layers.append(Dis_TransformerEncoderBlock(**kwargs))
        super().__init__(*layers)
        
        
class ClassificationHead(nn.Sequential):
    """Mean-pools the tokens, normalises, and projects to ``n_classes`` scores.

    No activation is applied to the output (the sigmoid is commented out in
    the original), so the returned values are raw scores.

    Args:
        emb_size: token width of the incoming ``(batch, tokens, emb_size)`` tensor.
        n_classes: number of output scores per sample.
    """

    def __init__(self, emb_size=100, n_classes=2):
        super().__init__()
        pool_tokens = Reduce('b n e -> b e', reduction='mean')
        normalize = nn.LayerNorm(emb_size)
        classify = nn.Linear(emb_size, n_classes)
        # Registered under ``clshead`` so parameter names stay ``clshead.*``.
        self.clshead = nn.Sequential(pool_tokens, normalize, classify)

    def forward(self, x):
        """Return the scores; a trailing axis of size 1 is squeezed away."""
        scores = self.clshead(x)
        return scores.squeeze(-1)

    
class PatchEmbedding_Linear(nn.Module):
    """Cuts a 4-D input into patches along its last axis and embeds them.

    The rearrangement pattern groups ``patch_size`` consecutive positions of
    the last axis (together with all channels) into one patch, projects each
    patch to ``emb_size`` with a linear layer, prepends a learned class token,
    and adds learned position vectors.

    Args:
        in_channels: size of the channel axis of the input.
        patch_size: number of consecutive positions per patch.
        emb_size: width of each produced token.
        seq_len: length used to size the position table, which holds
            ``seq_len // patch_size + 1`` rows (patches plus class token).
    """

    # Open question kept from the original author: what are the proper
    # default parameters here?
    def __init__(self, in_channels = 21, patch_size = 16, emb_size = 100, seq_len = 1024):
        super().__init__()
        to_patches = Rearrange('b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1 = 1, s2 = patch_size)
        to_embedding = nn.Linear(patch_size*in_channels, emb_size)
        self.projection = nn.Sequential(to_patches, to_embedding)
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        n_tokens = (seq_len // patch_size) + 1
        self.positions = nn.Parameter(torch.randn(n_tokens, emb_size))

    def forward(self, x: Tensor) -> Tensor:
        """Embed ``x`` (4-D, batch first) into ``(batch, tokens, emb_size)``."""
        batch_size, _, _, _ = x.shape
        patches = self.projection(x)
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=batch_size)
        # Class token goes first, followed by the patch tokens.
        tokens = torch.cat([cls_tokens, patches], dim=1)
        # Learned position vectors are broadcast over the batch axis.
        tokens += self.positions
        return tokens
        
        
class DiscriminatorTTS(nn.Sequential):
    """Patch-based transformer discriminator.

    Pipeline: PatchEmbedding_Linear -> Dis_TransformerEncoder ->
    ClassificationHead.

    Args:
        in_channels: feature dimension of the input (same role as the
            generator's ``channels``).
        patch_size: number of consecutive time steps grouped into one patch.
        emb_size: token width inside the transformer.
        seq_len: length of the input time sequence.
        depth: number of encoder blocks.
        num_heads: attention heads per block.
        n_classes: number of output scores (1 = a single real/fake score).
        **kwargs: forwarded to ``Dis_TransformerEncoder``. ``drop_p`` and
            ``forward_drop_p`` are fixed to 0.5 here, so passing them again
            raises ``TypeError``.
    """

    def __init__(self,
                 in_channels=3,
                 patch_size=15,
                 emb_size=50,
                 seq_len=150,
                 depth=3,
                 num_heads=5,
                 n_classes=1,
                 **kwargs):
        patch_embedding = PatchEmbedding_Linear(in_channels, patch_size, emb_size, seq_len)
        encoder = Dis_TransformerEncoder(
            depth,
            emb_size=emb_size,
            num_heads=num_heads,
            drop_p=0.5,
            forward_drop_p=0.5,
            **kwargs,
        )
        head = ClassificationHead(emb_size, n_classes)
        super().__init__(patch_embedding, encoder, head)
        

In [2]:
class PositionEmbeddingLearned(nn.Module):
    """
    Absolute pos embedding, learned.

    Holds a table of ``seq_len`` vectors of width ``num_pos_feats`` and adds
    the first ``x.shape[1]`` of them to the input.

    Args:
        seq_len: number of positions in the table (maximum sequence length).
        num_pos_feats: width of each position vector; must match the last
            dimension of the input.
    """
    def __init__(self, seq_len=27, num_pos_feats=512):
        super().__init__()
        self.embed = nn.Embedding(seq_len, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the table with ``nn.init.uniform_`` defaults."""
        nn.init.uniform_(self.embed.weight)

    def forward(self, x):
        """Add position vectors to ``x`` of shape ``(batch, length, num_pos_feats)``.

        Raises:
            IndexError: if ``length`` exceeds the size of the position table.
        """
        length = x.shape[1]
        # Fail early with a readable message instead of an out-of-range
        # lookup inside nn.Embedding.
        if length > self.embed.num_embeddings:
            raise IndexError(
                f"sequence length {length} exceeds the positional table size "
                f"{self.embed.num_embeddings}"
            )
        positions = torch.arange(length, device=x.device)
        # (length, feats) broadcasts over the batch axis; no per-batch copy.
        return x + self.embed(positions)

class Discriminator(nn.Module):
    """Encoder-decoder transformer that scores a sequence.

    The whole input sequence is embedded, given learned positional
    embeddings, and encoded. The last time step of the input is embedded
    separately and used as a single-token decoder target attending to the
    encoder output. A linear layer maps the decoder output to
    ``output_size`` values (no activation; the sigmoid is commented out).

    Args:
        device: device every sub-module is moved to at construction time.
        input_size: number of features per time step.
        src_seq_len: length of the input sequence; sizes the position table.
        batch_first: forwarded to the transformer layers. ``forward`` indexes
            the input as ``(batch, time, features)``, so ``True`` is expected.
        n_encoder_layers: number of encoder layers.
        n_decoder_layers: number of decoder layers.
        n_heads: attention heads in every layer.
        dim_val: model width (``d_model``).
        dropout_encoder: dropout inside encoder layers.
        dropout_decoder: dropout inside decoder layers.
        dropout_pos_enc: accepted for interface compatibility; not used.
        dim_feedforward_encoder: feed-forward width of encoder layers.
        dim_feedforward_decoder: feed-forward width of decoder layers.
        output_size: number of values produced per target token.
        pred_len: stored; also subtracted from ``src_seq_len`` for the
            ``src_seq_len`` attribute.
    """

    def __init__(self, device, input_size, src_seq_len, batch_first,
                 n_encoder_layers, n_decoder_layers, n_heads, dim_val,
                 dropout_encoder, dropout_decoder, dropout_pos_enc,
                 dim_feedforward_encoder, dim_feedforward_decoder,
                 output_size, pred_len):
        super(Discriminator, self).__init__()
        self.output_size = output_size
        self.device = device
        self.input_size = input_size
        # The attribute holds the source length without the predicted steps;
        # the position table below is still sized with the full src_seq_len.
        self.src_seq_len = src_seq_len - pred_len
        self.pred_len = pred_len

        self.encoder_input_layer = nn.Linear(in_features=input_size, out_features=dim_val)
        self.encoder_input_layer.to(self.device)
        self.decoder_input_layer = nn.Linear(in_features=input_size, out_features=dim_val)
        self.decoder_input_layer.to(self.device)

        self.positional_encoding_layer = PositionEmbeddingLearned(seq_len=src_seq_len,
                                                                  num_pos_feats=dim_val)
        self.positional_encoding_layer.to(self.device)

        encoder_layer = nn.TransformerEncoderLayer(d_model=dim_val, nhead=n_heads,
                                                   dim_feedforward=dim_feedforward_encoder,
                                                   dropout=dropout_encoder,
                                                   batch_first=batch_first,
                                                   activation='relu',
                                                   norm_first=True
                                                   )
        self.encoder = nn.TransformerEncoder(encoder_layer=encoder_layer,
                                             num_layers=n_encoder_layers,
                                             norm=None
                                             )
        self.encoder.to(self.device)

        decoder_layer = nn.TransformerDecoderLayer(d_model=dim_val, nhead=n_heads,
                                                   dim_feedforward=dim_feedforward_decoder,
                                                   dropout=dropout_decoder,
                                                   batch_first=batch_first,
                                                   activation='relu',
                                                   norm_first=True
                                                   )
        self.decoder = nn.TransformerDecoder(decoder_layer=decoder_layer,
                                             num_layers=n_decoder_layers,
                                             norm=None
                                             )
        self.decoder.to(self.device)

        self.decoder_output_layer = nn.Sequential(torch.nn.Linear(in_features=dim_val, out_features=output_size),
                                                  # torch.nn.Sigmoid()
                                                  )
        # The output head must live on the same device as the other
        # sub-modules; otherwise forward() fails on a non-CPU device.
        self.decoder_output_layer.to(self.device)

        # self.src_mask = torch.triu(torch.ones((self.pred_len + 1, self.src_seq_len),
        #                                       device=self.device) * float('-inf'), diagonal=1).to(torch.bool)
        # self.tgt_mask = torch.triu(torch.ones((self.pred_len + 1, self.pred_len + 1),
        #                                       device=self.device) * float('-inf'), diagonal=1).to(torch.bool)

    def forward(self, src):
        """Score ``src`` of shape ``(batch, time, input_size)``.

        Returns:
            Output of the final linear layer applied to the single decoder
            token, i.e. ``(batch, 1, output_size)`` for batch-first input.
        """
        # The last time step is the (single-token) decoder target.
        tgt = src[:, -1, :].clone().unsqueeze(1)
        # src = src[:, :, :]

        encoder_output = self.encoder_input_layer(src)
        encoder_output = self.positional_encoding_layer(encoder_output)
        encoder_output = self.encoder(encoder_output)

        decoder_output = self.decoder_input_layer(tgt)
        # No masks: the target is one token and may attend to the whole memory.
        decoder_output = self.decoder(tgt=decoder_output, memory=encoder_output, tgt_mask=None, memory_mask=None)

        out = self.decoder_output_layer(decoder_output)

        return out


In [51]:
# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Generator: 471-step sequences with 9 channels.
gen = Generator(
    seq_len=471,
    patch_size=471,
    channels=9,
    latent_dim=100,
    embed_dim=256,
    depth=6,
    num_heads=32,
)
gen_numel = sum(p.numel() for p in gen.parameters())

# Patch-based discriminator: 471 steps split into patches of 157 steps.
dis_tts = DiscriminatorTTS(
    in_channels=9,
    patch_size=157,
    emb_size=384,
    seq_len=471,
    depth=3,
    num_heads=8,
)
distts_numel = sum(p.numel() for p in dis_tts.parameters())

# Encoder-decoder discriminator.
dis = Discriminator(
    device,
    input_size=9,
    src_seq_len=471,
    batch_first=True,
    n_encoder_layers=8,
    n_decoder_layers=8,
    n_heads=8,
    dim_val=64,
    dropout_encoder=0.2,
    dropout_decoder=0.2,
    dropout_pos_enc=0.1,
    dim_feedforward_encoder=1024,
    dim_feedforward_decoder=1024,
    output_size=1,
    pred_len=1,
)
dis_numel = sum(p.numel() for p in dis.parameters())

# Parameter counts and generator/discriminator size ratios.
print(f'G_tts {gen_numel}, D_tts {distts_numel}, D {dis_numel}, G_tts/D_tts {(gen_numel / distts_numel):.2f}, G_tts/D {(gen_numel / dis_numel):.2f}')

cpu
G_tts 17039625, D_tts 5869441, D 2550529, G_tts/D_tts 2.90, G_tts/D 6.68


In [31]:
%%time
# z = torch.FloatTensor(2, 100).normal_(0, 1)
z = torch.FloatTensor(np.random.normal(0, 1, (1, 100)))
g_out = gen(z)
print(g_out.shape)
print(g_out.squeeze(2).transpose(1,2).shape)
d_out = dis(g_out.squeeze(2).transpose(1,2)).view(-1)
print(d_out.shape)
print(d_out)

# real_seq = torch.FloatTensor(1, 54, 557).normal_(0, 1)
# print(real_seq.shape)
# d_out = dis(real_seq.transpose(1, 2).unsqueeze(2))
# print(d_out.shape)
# print(d_out)

torch.Size([1, 9, 1, 471])
torch.Size([1, 471, 9])
torch.Size([1])
tensor([0.0837], grad_fn=<ViewBackward0>)
CPU times: user 4.08 s, sys: 2.92 s, total: 7.01 s
Wall time: 989 ms


In [None]:
# NOTE(review): scratch copy of the body of Generator.forward with debug
# prints. It cannot run as a notebook cell: `self` is not defined at top level
# and `return` is only valid inside a function (the cell prompt is "In [None]",
# i.e. it was never executed).
# Differences from Generator.forward: the singleton axis is appended last
# (unsqueeze(-1) instead of unsqueeze(-2)) and the convolution output is then
# reshaped to (-1, channels, H, W).
x = self.l1(z).view(-1, self.seq_len, self.embed_dim)
x = x + self.pos_embed
H, W = 1, self.seq_len
x = self.blocks(x)
print(x.shape)
x = x.transpose(1, 2).unsqueeze(-1)
# x = x.reshape(x.shape[0], x.shape[1], 1, x.shape[2])
# print(x.shape)
# x = x.permute(0, 3, 1, 2)
print(x.shape)
output = self.deconv(x)
print(output.shape)
output = output.view(-1, self.channels, H, W)
print(output.shape)
return output

In [32]:
count_params = 0
# Gather every parameter tensor of the encoder-decoder discriminator and
# report how many there are.
params_to_update = list(dis.parameters())
print(len(params_to_update))

136
