In [1]:
#importar dependencias
import os
from functools import partial

import torch
import torch.nn as nn
from torch.cuda.amp import autocast

os.environ['TORCH_HOME'] = '../../pretrained_models'
import timm
from timm.models.layers import to_2tuple,trunc_normal_
import wget

In [2]:
# override the timm package to relax the input shape constraint.
class PatchEmbed(nn.Module):
    '''
    Image-to-patch embedding.

    Splits a (B, in_chans, H, W) input into patches with a Conv2d whose kernel size and
    stride both equal patch_size, and returns one embed_dim-dimensional token per patch,
    i.e. a tensor of shape (B, num_tokens, embed_dim).

    Unlike timm's PatchEmbed, forward() does not check the input size against img_size,
    so other input sizes are accepted; img_size is only used to compute num_patches.

    Parameters:
        img_size: nominal input size (int or (H, W)), used to compute num_patches
        patch_size: patch size (int or (h, w)); also the conv kernel size and stride
        in_chans: number of input channels
        embed_dim: embedding dimension produced for each patch
    '''
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()

        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        # Patch count for an input of exactly img_size with non-overlapping patches.
        self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        # Exposed so that owners of this module can read the channel count and the
        # embedding width (e.g. to size positional embeddings or class tokens).
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # (B, C, H, W) -> (B, embed_dim, H', W') -> (B, embed_dim, H'*W') -> (B, H'*W', embed_dim)
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x

In [8]:
class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks.

    Computes fc1 -> act -> drop1 -> norm -> fc2 -> drop2, where `norm` is
    nn.Identity unless `norm_layer` is given.

    Args:
        in_features: input feature dimension.
        hidden_features: hidden dimension; defaults to in_features.
        out_features: output dimension; defaults to in_features.
        act_layer: activation class, instantiated without arguments.
        norm_layer: optional normalization class applied to the hidden features
            (called as norm_layer(hidden_features)).
        bias: bias flag(s) for fc1 / fc2, expanded to a pair with to_2tuple.
        drop: dropout probability(ies) after fc1 / fc2, expanded to a pair with to_2tuple.
        use_conv: if True, use 1x1 nn.Conv2d layers instead of nn.Linear.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
                 norm_layer=None, bias=True, drop=0., use_conv=False):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        # A 1x1 convolution acts as a per-position linear layer on channel-first inputs.
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        # Optional normalization of the hidden features (nn.Identity when norm_layer is None).
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x

In [9]:
class Block(nn.Module):
    """Pre-norm transformer encoder block.

    forward computes:
        x = x + drop_path1(ls1(attn(norm1(x))))
        x = x + drop_path2(ls2(mlp(norm2(x))))

    LayerScale (ls1/ls2) is used only when init_values is truthy and DropPath only when
    drop_path > 0; otherwise both are nn.Identity. proj_drop is also passed to the MLP
    as its dropout probability.

    NOTE(review): Attention, LayerScale and DropPath are not defined or imported in the
    visible source; presumably they are the timm implementations
    (timm.models.vision_transformer / timm.models.layers) -- TODO import them explicitly.
    """

    def __init__(
            self,
            dim,
            num_heads,
            mlp_ratio=4.,
            qkv_bias=False,
            qk_norm=False,
            proj_drop=0.,
            attn_drop=0.,
            init_values=None,
            drop_path=0.,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm,
            mlp_layer=Mlp):
        super().__init__()
        # Attention sub-layer (residual branch 1).
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm,
                              attn_drop=attn_drop, proj_drop=proj_drop, norm_layer=norm_layer)
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # MLP sub-layer (residual branch 2); hidden width is dim * mlp_ratio.
        self.norm2 = norm_layer(dim)
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x


In [11]:
class ASTModel(nn.Module):
    """
    The AST (Audio Spectrogram Transformer) model.

    The input spectrogram is cut into 16x16 patches (overlapping depending on
    fstride/tstride), embedded, prefixed with a class token and a distillation token,
    passed through a stack of transformer Blocks, and classified from the average of
    the two token outputs.

    :param label_dim: number of classes (label dimension), e.g. 527 for AudioSet, 50 for ESC-50, 35 for Speech Commands v2-35
    :param fstride: patch stride on the frequency axis; for 16x16 patches, fstride=16 means no overlap, fstride=10 means an overlap of 6
    :param tstride: patch stride on the time axis; for 16x16 patches, tstride=16 means no overlap, tstride=10 means an overlap of 6
    :param input_fdim: number of frequency bins of the input spectrogram
    :param input_tdim: number of time frames of the input spectrogram
    :param imagenet_pretrain: whether to use an ImageNet-pretrained model. NOTE: currently unused, no weights are loaded
    :param audioset_pretrain: whether to use a full AudioSet + ImageNet-pretrained model. NOTE: currently unused
    :param model_size: AST size, one of [tiny224, small224, base224, base384]. NOTE: currently unused, the architecture below is fixed
    """
    def __init__(self, label_dim=527, fstride=10, tstride=10, input_fdim=128, input_tdim=1024, imagenet_pretrain=True,
                 audioset_pretrain=False, model_size='base384'):

        super(ASTModel, self).__init__()
        # NOTE(review): no timm.create_model call remains, and Block's kwargs (qk_norm, LayerScale)
        # look like a newer timm API than 0.4.5 -- confirm this version pin is still intended.
        assert timm.__version__ == '0.4.5', 'Please use timm == 0.4.5, the code might not be compatible with newer versions.'

        # Patch embedding built with PatchEmbed's defaults. Only its default patch count is
        # kept (original_num_patches / oringal_hw); the projection is replaced below.
        self.patch_embed = PatchEmbed()
        self.original_num_patches = self.patch_embed.num_patches
        # Attribute name keeps its original spelling so existing callers keep working.
        self.oringal_hw = int(self.original_num_patches ** 0.5)

        # Take the embedding width from the projection conv itself.
        self.embed_dim = self.patch_embed.proj.out_channels
        self.original_embedding_dim = self.embed_dim

        # Classification head applied to the pooled token representation.
        self.mlp_head = nn.Sequential(nn.LayerNorm(self.original_embedding_dim),
                                      nn.Linear(self.original_embedding_dim, label_dim))

        # Automatically get the intermediate (frequency, time) patch-grid shape.
        f_dim, t_dim = self.get_shape(fstride, tstride, input_fdim, input_tdim)
        num_patches = f_dim * t_dim
        self.patch_embed.num_patches = num_patches
        # Linear projection for single-channel spectrograms: 16x16 kernels with the requested strides.
        self.patch_embed.proj = nn.Conv2d(1, self.original_embedding_dim, kernel_size=(16, 16), stride=(fstride, tstride))

        # Learnable positional embedding, one slot per patch plus the cls and dist tokens,
        # randomly initialized (no pretrained weights are loaded).
        # TODO can use sinusoidal positional embedding instead
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.original_embedding_dim))
        trunc_normal_(self.pos_embed, std=.02)

        # Class and distillation tokens, prepended to the patch sequence in forward().
        # See https://huggingface.co/spaces/Hila/RobustViT/blob/main/ViT/ViT_new.py
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))

        pos_drop_rate = 0.0
        self.pos_drop = nn.Dropout(p=pos_drop_rate)

        # Transformer encoder hyper-parameters. Block's other options (qk_norm, init_values,
        # proj_drop, attn_drop, drop_path, act_layer, mlp_layer) are left at their defaults.
        # TODO(review): earlier notes mention patch_size=16, embed_dim=384 (DeiT-small-like),
        # while PatchEmbed() defaults to embed_dim=768 -- confirm the intended width.
        depth = 12
        num_heads = 6
        mlp_ratio = 4
        qkv_bias = True
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.blocks = nn.Sequential(*[
            Block(
                dim=self.embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                norm_layer=norm_layer,
            )
            for _ in range(depth)])
        # Final normalization over every token before pooling.
        self.norm = nn.LayerNorm(self.embed_dim)

    def get_shape(self, fstride, tstride, input_fdim=128, input_tdim=1024):
        """Return (f_dim, t_dim), the patch-grid size produced by a 16x16 conv with strides
        (fstride, tstride) on a (input_fdim, input_tdim) single-channel input."""
        test_input = torch.randn(1, 1, input_fdim, input_tdim)
        test_proj = nn.Conv2d(1, self.original_embedding_dim, kernel_size=(16, 16), stride=(fstride, tstride))
        # Only the output shape is needed, so skip building an autograd graph.
        with torch.no_grad():
            test_out = test_proj(test_input)
        f_dim = test_out.shape[2]
        t_dim = test_out.shape[3]
        return f_dim, t_dim

    @autocast()
    def forward(self, x):
        """
        :param x: the input spectrogram, expected shape: (batch_size, time_frame_num, frequency_bins), e.g., (12, 1024, 128)
        :return: un-normalized class scores of shape (batch_size, label_dim)
        """
        # (B, T, F) -> (B, 1, T, F) -> (B, 1, F, T): frequency becomes the conv height axis.
        x = x.unsqueeze(1)
        x = x.transpose(2, 3)

        B = x.shape[0]
        # (B, f_dim * t_dim, embed_dim) for the configured input size.
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        dist_token = self.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        # Pool by averaging the cls-token and dist-token outputs.
        x = (x[:, 0] + x[:, 1]) / 2

        x = self.mlp_head(x)
        return x


