In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from gat import GATLayer, GAT, GraphAttentionLayer
import torch
import torch.nn.functional as F

## GraphAttentionLayer

:::{note}
`GraphAttentionLayer` comes from RawGAT-ST.
:::

In [3]:
m = GraphAttentionLayer(in_dim=128, out_dim=64)
x = torch.randn(2, 60, 128)
m(x).shape

torch.Size([2, 60, 64])

## GATLayer and GAT

:::{note}
`GATLayer` and `GAT` come from [gordicaleksa/pytorch-GAT: My implementation of the original GAT paper (Veličković et al.)](https://github.com/gordicaleksa/pytorch-GAT/tree/main)
:::

:::{warning}
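Things to keep in mind when using `GATLayer`:
1. The actual output feature size is `num_out_features * num_of_heads`!
2. `edge_index` has shape `(2, N)` and represents $N$ edges.
3. `GATLayer` cannot handle batched input, so a batch must be reshaped from `(B, T, C)` to `(B*T, C)`.
4. The returned `edge_index` is identical to the input `edge_index`. This keeps both the input and the output of the layer a tuple, so layers can be chained with `torch.nn.Sequential`.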
:::
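
Points 2 and 3 can be combined in a small helper. The sketch below is an assumption about how one might flatten a batch, not code from the referenced repository: `flatten_batch_for_gat` is a hypothetical name, and it assumes every sample in the batch shares the same `(2, E)` edge list over nodes `0..T-1`. After flattening, sample `b` owns rows `b*T` to `(b+1)*T - 1`, so its edges are shifted by `b*T` to keep the samples disconnected from each other.

```python
import torch


def flatten_batch_for_gat(x, edge_index):
    """Flatten (B, T, C) features to (B*T, C) and replicate the edges per sample.

    x:          float tensor of shape (B, T, C)
    edge_index: long tensor of shape (2, E), indexing nodes 0..T-1 of ONE sample
    (Hypothetical helper; assumes all samples share the same graph structure.)
    """
    B, T, C = x.shape
    num_edges = edge_index.shape[1]
    flat_x = x.reshape(B * T, C)
    # Offset for each replicated edge: 0 for sample 0, T for sample 1, and so on.
    offsets = torch.arange(B).repeat_interleave(num_edges) * T  # shape (B*E,)
    # Tile the edge list B times along the edge axis, then shift each copy.
    batched_edges = edge_index.repeat(1, B) + offsets  # shape (2, B*E)
    return flat_x, batched_edges


x = torch.randn(2, 4, 8)  # B=2 samples, T=4 nodes, C=8 features
edge_index = torch.tensor([[0, 1], [1, 2], [2, 3]]).T  # shape (2, 3)
flat_x, batched_edges = flatten_batch_for_gat(x, edge_index)
print(flat_x.shape, batched_edges.shape)
```

The resulting pair can be passed to the layer as `m((flat_x, batched_edges))`, and the output reshaped back to `(B, T, -1)` afterwards.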

In [4]:
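# Three edges (0, 1), (1, 2), (2, 3); transposing gives shape (2, 3)
edge_index = torch.tensor([[0, 1], [1, 2], [2, 3]]).T

m = GATLayer(num_in_features=128, num_out_features=128, num_of_heads=4)
x = torch.randn(3600, 128)  # 3600 nodes, 128 features each
y, edge_index = m((x, edge_index))  # the layer takes and returns a (features, edge_index) tuple
print(y.shape)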

torch.Size([3600, 512])


### GAT model

In [8]:
gat_config = {
    # GNNs, contrary to CNNs, are often shallow (it ultimately depends on the graph properties)
    "num_of_layers": 3,  # PPI has got 42% of nodes with all 0 features - that's why 3 layers are useful
    "num_heads_per_layer": [6, 6, 6],  # other values may give even better results from the reported ones
    "num_features_per_layer": [768, 128, 128, 1],  # the first number is actually input dim
    "add_skip_connection": True,  # skip connection is very important! (keep it otherwise micro-F1 is almost 0)
    "bias": True,  # bias doesn't matter that much
    "dropout": 0.0,  # dropout hurts the performance (best to keep it at 0)
}

In [9]:
model = GAT(**gat_config)

In [10]:
hidden_states = torch.randn(1200, 768)
edge_index = torch.randint(0, 1200, (2, 100))
logit, edge_index = model((hidden_states, edge_index))
logit.shape

torch.Size([1200, 1])
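
The shapes above are consistent with the following bookkeeping, sketched here as an assumption rather than taken from the `gat` module: every layer except the last concatenates its heads, so it outputs `features * heads`, and the last layer averages them, which would explain why six heads still yield a final width of 1. `gat_layer_dims` is a hypothetical helper.

```python
def gat_layer_dims(num_features_per_layer, num_heads_per_layer):
    """Return (in_dim, out_dim) for each layer.

    Assumes heads are concatenated on every layer except the last,
    where they are averaged (an assumption inferred from the output shape).
    """
    assert len(num_features_per_layer) == len(num_heads_per_layer) + 1
    heads = [1] + list(num_heads_per_layer)  # treat the raw input as a single head
    num_layers = len(num_heads_per_layer)
    dims = []
    for i in range(num_layers):
        in_dim = num_features_per_layer[i] * heads[i]
        is_last = i == num_layers - 1
        out_dim = num_features_per_layer[i + 1] * (1 if is_last else heads[i + 1])
        dims.append((in_dim, out_dim))
    return dims


dims = gat_layer_dims([768, 128, 128, 1], [6, 6, 6])
print(dims)  # [(768, 768), (768, 768), (768, 1)]
```

Under this assumption, the hidden layers are 768 wide (128 features times 6 heads), the same as the input dimension.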

# Wav2Vec

In [5]:
from transformers import Wav2Vec2ForCTC, WavLMForCTC
import torch

In [3]:
m = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

Some weights of the model checkpoint at facebook/wav2vec2-base-960h were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You sho

In [6]:
x = torch.randn(2, 48000)
feat1 = m.wav2vec2.feature_extractor(x)

In [7]:
x = torch.randn(2, 48000)
feat1 = m.wav2vec2.feature_extractor(x).transpose(1, 2)
# extract_features is feat1 after layer norm; hidden_states is its linear projection to 768 dims
hidden_states, extract_features = m.wav2vec2.feature_projection(feat1)

In [8]:
hidden_states.shape

torch.Size([2, 149, 768])

In [9]:
print(feat1.shape, extract_features.shape)

torch.Size([2, 149, 512]) torch.Size([2, 149, 512])
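
The time dimension of 149 follows from the kernel sizes and strides of the convolutional feature encoder. The sketch below recomputes it, assuming the default wav2vec 2.0 base configuration of seven 1-D convolutions with kernel sizes `(10, 3, 3, 3, 3, 2, 2)`, strides `(5, 2, 2, 2, 2, 2, 2)`, and no padding. These values are not shown in the notebook, so check them against `m.config.conv_kernel` and `m.config.conv_stride`. `conv_output_length` is a hypothetical helper name.

```python
def conv_output_length(
    num_samples,
    kernels=(10, 3, 3, 3, 3, 2, 2),  # assumed defaults for the base model
    strides=(5, 2, 2, 2, 2, 2, 2),   # assumed defaults for the base model
):
    """Number of frames produced by a stack of unpadded 1-D convolutions."""
    length = num_samples
    for kernel, stride in zip(kernels, strides):
        # Standard output-length formula for a convolution without padding.
        length = (length - kernel) // stride + 1
    return length


print(conv_output_length(48000))  # 149, matching the shapes above
```

The strides multiply to 320 samples, so under these assumptions each frame covers roughly 20 ms of 16 kHz audio.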


In [19]:
final_feat = m.wav2vec2.encoder(hidden_states)[0]

logits = m.lm_head(final_feat)
print(final_feat.shape, logits.shape)

torch.Size([2, 149, 768]) torch.Size([2, 149, 32])
