## 건드리지 않기

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LSTM(nn.Module):
    """Bidirectional LSTM encoder with learned initial states.

    The recurrent outputs (optionally concatenated with the raw inputs) go
    through dropout and a linear projection to ``dim`` features, and the
    result is L2-normalised over the last axis.

    Args:
        embed_size: Size of the last axis of the inputs.
        dim: Size of the last axis of the outputs. Each LSTM direction uses
            ``dim // 2`` hidden units.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability applied before the projection and
            between stacked LSTM layers.
        residual_embeddings: When True, the inputs are concatenated to the
            LSTM outputs before the projection.
    """

    def __init__(self, embed_size, dim, num_layers, dropout, residual_embeddings=True):
        super(LSTM, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        self.rnn_dim = dim // 2
        # The projection must accept exactly what forward() feeds it: both
        # LSTM directions, plus the raw inputs only when they are concatenated.
        # (For even ``dim`` and residual_embeddings=True this is dim + embed_size.)
        rnn_out_dim = 2 * self.rnn_dim
        linear_in = rnn_out_dim + (embed_size if residual_embeddings else 0)
        self.linear = nn.Linear(linear_in, dim)
        # Inter-layer dropout has no effect on a single-layer LSTM and only
        # triggers a warning there, so it is passed only when layers are stacked.
        self.rnn = nn.LSTM(embed_size, self.rnn_dim, num_layers=num_layers,
                           dropout=dropout if num_layers > 1 else 0.0,
                           bidirectional=True, batch_first=True)
        self.residual_embeddings = residual_embeddings
        # Learned initial states: the first 2 * num_layers rows are used as h0,
        # the remaining 2 * num_layers rows as c0.
        self.init_hidden = nn.Parameter(nn.init.xavier_uniform_(torch.empty(2 * 2 * num_layers, self.rnn_dim)))
        self.num_layers = num_layers

    def forward(self, inputs):
        """Encode ``inputs`` of shape (batch, seq, embed_size) into (batch, seq, dim)."""
        batch = inputs.size(0)
        num_states = 2 * self.num_layers  # layers * directions
        # Broadcast the learned states across the batch dimension.
        h0 = self.init_hidden[:num_states].unsqueeze(1).expand(
            num_states, batch, self.rnn_dim).contiguous()
        c0 = self.init_hidden[num_states:].unsqueeze(1).expand(
            num_states, batch, self.rnn_dim).contiguous()

        print("LSTM inputs : ", inputs.shape)
        outputs, _ = self.rnn(inputs, (h0, c0))

        if self.residual_embeddings:
            outputs = torch.cat([inputs, outputs], dim=-1)
        outputs = self.linear(self.dropout(outputs))

        return F.normalize(outputs, dim=-1)

In [2]:
import math
import torch
import torch.nn as nn
import sys

class DenseCoAttn(nn.Module):
    """Joint co-attention between two equally long feature sequences.

    Both streams are concatenated into a joint representation, each stream
    is correlated with it, and the correlation is used to re-weight that
    stream's projected features.

    Args:
        dim1: Feature size of the first stream (video in the caller).
        dim2: Feature size of the second stream (audio in the caller).
        dropout: Dropout probability applied to the correlation scores.
        seq_len: Length of the sequence axis (axis 1) of both inputs. The
            key projections act along that axis, so the inputs must have
            exactly this many steps. Defaults to 16.
    """

    def __init__(self, dim1, dim2, dropout, seq_len=16):
        super(DenseCoAttn, self).__init__()
        dim = dim1 + dim2
        self.dropouts = nn.ModuleList([nn.Dropout(p=dropout) for _ in range(2)])
        self.query_linear = nn.Linear(dim, dim)
        # Applied to the transposed inputs, i.e. along the sequence axis.
        self.key1_linear = nn.Linear(seq_len, seq_len)
        self.key2_linear = nn.Linear(seq_len, seq_len)
        self.value1_linear = nn.Linear(dim1, dim1)
        self.value2_linear = nn.Linear(dim2, dim2)
        self.relu = nn.ReLU()

    def forward(self, value1, value2):
        """Return the attended features for both streams.

        Args:
            value1: Tensor of shape (batch, seq_len, dim1).
            value2: Tensor of shape (batch, seq_len, dim2).

        Returns:
            Tuple ``(weighted1, weighted2)``, each of shape
            (batch, seq_len, dim1 + dim2).
        """
        print("DenseCoAttn input value1(video) : ", value1.shape)
        print("DenseCoAttn input value2(audio) : ", value2.shape)
        # Joint representation of both streams, linearly mixed.
        joint = torch.cat((value1, value2), dim=-1)
        joint = self.query_linear(joint)
        print("DenseCoAttn joint representation : ", joint.shape)
        key1 = self.key1_linear(value1.transpose(1, 2))  # X_v^T: (batch, dim1, seq_len)
        key2 = self.key2_linear(value2.transpose(1, 2))  # X_a^T: (batch, dim2, seq_len)
        print("DenseCoAttn X_v^T : ", key1.shape)
        print("DenseCoAttn X_a^T : ", key2.shape)

        # Per-stream linear projection of the features that get re-weighted.
        value1 = self.value1_linear(value1)
        value2 = self.value2_linear(value2)
        print("DenseCoAttn value1 after value_linear : ", value1.shape)
        print("DenseCoAttn value2 after value_linear : ", value2.shape)

        weighted1, _ = self.qkv_attention(joint, key1, value1, dropout=self.dropouts[0])
        weighted2, _ = self.qkv_attention(joint, key2, value2, dropout=self.dropouts[1])
        print("DenseCoAttn weighted1 : ", weighted1.shape)
        print("DenseCoAttn weighted2 : ", weighted2.shape)

        return weighted1, weighted2

    def qkv_attention(self, query, key, value, dropout=None):
        """Correlate ``key`` with ``query`` and re-weight ``value`` with the result.

        Args:
            query: Tensor of shape (batch, seq_len, d_k).
            key: Tensor of shape (batch, d, seq_len).
            value: Tensor of shape (batch, seq_len, d).
            dropout: Optional module applied to the scores.

        Returns:
            Tuple ``(weighted, scores)`` with ``weighted`` of shape
            (batch, seq_len, d_k) and ``scores`` of shape (batch, d, d_k).
        """
        d_k = query.size(-1)
        scores = torch.bmm(key, query) / math.sqrt(d_k)
        scores = torch.tanh(scores)  # C_v, C_a
        # Explicit None check: do not rely on the truthiness of an nn.Module.
        if dropout is not None:
            scores = dropout(scores)

        weighted = torch.tanh(torch.bmm(value, scores))
        return self.relu(weighted), scores  # self.relu(weighted) == H_v, H_a


In [3]:
import torch.nn as nn

class NormalSubLayer(nn.Module):
    """One co-attention step with a residual update of both streams.

    Each stream's attended features (width ``dim1 + dim2``) are projected
    back to that stream's own width by Linear -> ReLU -> Dropout and added
    to the stream's input.
    """

    def __init__(self, dim1, dim2, dropout):
        super().__init__()
        self.dense_coattn = DenseCoAttn(dim1, dim2, dropout)
        joint_dim = dim1 + dim2
        # Index 0 projects back to the first stream's width, index 1 to the second's.
        self.linears = nn.ModuleList([
            nn.Sequential(
                nn.Linear(joint_dim, stream_dim),
                nn.ReLU(inplace=True),
                nn.Dropout(p=dropout),
            )
            for stream_dim in (dim1, dim2)
        ])

    def forward(self, data1, data2):
        """Return both streams after the residual co-attention update."""
        attended1, attended2 = self.dense_coattn(data1, data2)
        updated1 = data1 + self.linears[0](attended1)  # X_att,v
        updated2 = data2 + self.linears[1](attended2)  # X_att,a

        print("DCNLayer X_att,v : ", updated1.shape)
        print("DCNLayer X_att,a : ", updated2.shape)

        return updated1, updated2


class DCNLayer(nn.Module):
    """Stack of ``num_seq`` NormalSubLayer blocks applied one after another."""

    def __init__(self, dim1, dim2, num_seq, dropout):
        super().__init__()
        self.dcn_layers = nn.ModuleList()
        for _ in range(num_seq):
            self.dcn_layers.append(NormalSubLayer(dim1, dim2, dropout))

    def forward(self, data1, data2):
        """Feed both streams through every sub-layer in order and return them."""
        stream1, stream2 = data1, data2
        for sub_layer in self.dcn_layers:
            stream1, stream2 = sub_layer(stream1, stream2)
        return stream1, stream2

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys

class BottomUpExtract(nn.Module):
    """Thin wrapper that delegates to a PositionAttn module."""

    def __init__(self, emed_dim, dim):
        super().__init__()
        # ``emed_dim`` is forwarded as PositionAttn's first (embedding size) argument.
        self.attn = PositionAttn(emed_dim, dim)

    def forward(self, video, audio):
        """Return the output of the attention module for ``video`` and ``audio``."""
        return self.attn(video, audio)

# audio-guided attention
class PositionAttn(nn.Module):
    """Audio-guided attention over the regions of each video step.

    For every (batch, step) pair the video features are split into regions
    of ``video_dim`` features each. An attention distribution over those
    regions is computed from the projected video and audio features, and
    the regions are pooled into one ``video_dim`` vector, which is then
    projected to ``dim``.

    Args:
        embed_dim: Size of the last axis of ``audio``.
        dim: Hidden size of the projections and size of the output features.
        video_dim: Feature size of a single video region. Defaults to 512.
        num_regions: Number of regions per video step. The same value is
            used as the width of the attention scoring layers, so the
            video input must contain exactly this many regions per step.
            Defaults to 49.
    """

    def __init__(self, embed_dim, dim, video_dim=512, num_regions=49):
        super(PositionAttn, self).__init__()
        self.video_dim = video_dim
        self.affine_audio = nn.Linear(embed_dim, dim)
        self.affine_video = nn.Linear(video_dim, dim)
        self.affine_v = nn.Linear(dim, num_regions, bias=False)
        self.affine_g = nn.Linear(dim, num_regions, bias=False)
        self.affine_h = nn.Linear(num_regions, 1, bias=False)
        self.affine_feat = nn.Linear(video_dim, dim)
        self.relu = nn.ReLU()

    def forward(self, video, audio):
        """Pool the video regions under audio guidance.

        Args:
            video: Tensor whose first two axes are (batch, steps) and whose
                remaining elements per step are ``num_regions * video_dim``.
            audio: Tensor of shape (batch, steps, embed_dim).

        Returns:
            Tensor of shape (batch, steps, dim).
        """
        batch, steps = video.size(0), video.size(1)
        # reshape (unlike view) also accepts non-contiguous inputs.
        regions = video.reshape(batch * steps, -1, self.video_dim)

        # Audio-guided visual attention
        v_t = self.relu(self.affine_video(regions))
        a_t = self.relu(self.affine_audio(audio.reshape(-1, audio.size(-1))))

        # The audio term carries one value per region and is broadcast over
        # the last axis of the video term.
        content_v = self.affine_v(v_t) + self.affine_g(a_t).unsqueeze(2)

        z_t = self.affine_h(torch.tanh(content_v)).squeeze(2)

        # Attention map over regions, shaped (batch * steps, 1, num_regions).
        alpha_t = F.softmax(z_t, dim=-1).unsqueeze(1)

        # Attention-weighted sum of the unprojected region features.
        c_t = torch.bmm(alpha_t, regions)
        video_t = c_t.reshape(batch, -1, self.video_dim)

        return self.affine_feat(video_t)


In [14]:
from __future__ import absolute_import
from __future__ import division

from torch.nn import init
import torch
import math
from torch import nn
from torch.nn import functional as F
import sys

class LSTM_CAM(nn.Module):
    """Audio-visual fusion model with two regression heads.

    Audio is encoded by an LSTM, video is pooled by audio-guided attention
    and encoded by a second LSTM, both streams go through the co-attention
    stack, and a joint LSTM feeds two regressors (``vregressor`` and
    ``aregressor``) that each emit one value per time step.
    """

    def __init__(self):
        super().__init__()

        def build_regressor():
            # 512 -> 128 -> 1 head; both regressors share this layout.
            return nn.Sequential(
                nn.Linear(512, 128),
                nn.ReLU(inplace=True),
                nn.Dropout(0.6),
                nn.Linear(128, 1),
            )

        self.coattn = DCNLayer(512, 512, 2, 0.6)

        self.audio_extract = LSTM(512, 512, 2, 0.1, residual_embeddings=True)
        self.video_extract = LSTM(512, 512, 2, 0.1, residual_embeddings=True)

        self.video_attn = BottomUpExtract(512, 512)
        self.vregressor = build_regressor()

        # Input width 2048 matches the two zero-padded 1024-wide streams
        # concatenated in forward().
        self.Joint = LSTM(2048, 512, 2, dropout=0, residual_embeddings=True)

        self.aregressor = build_regressor()

        self.init_weights()

    def init_weights(net, init_type='xavier', init_gain=1):
        """Move the network to CUDA when available and initialise Conv/Linear layers.

        Args:
            init_type: One of 'normal', 'xavier', 'kaiming' or 'orthogonal'.
            init_gain: Gain (or upper bound for 'normal') handed to the initialiser.

        Raises:
            NotImplementedError: If ``init_type`` is not one of the above and
                a Conv/Linear layer is encountered.
        """
        if torch.cuda.is_available():
            net.cuda()

        # NOTE(review): 'normal' draws from uniform(0, init_gain), not from a
        # normal distribution -- confirm this is intended.
        weight_inits = {
            'normal': lambda w: init.uniform_(w, 0.0, init_gain),
            'xavier': lambda w: init.xavier_uniform_(w, gain=init_gain),
            'kaiming': lambda w: init.kaiming_uniform_(w, a=0, mode='fan_in'),
            'orthogonal': lambda w: init.orthogonal_(w, gain=init_gain),
        }

        def init_func(m):
            # Only modules that own a weight and whose class name mentions
            # Conv or Linear are re-initialised.
            layer_kind = m.__class__.__name__
            if not hasattr(m, 'weight'):
                return
            if 'Conv' not in layer_kind and 'Linear' not in layer_kind:
                return
            if init_type not in weight_inits:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            weight_inits[init_type](m.weight.data)
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias.data, 0.0)

        print('initialize network with %s' % init_type)
        net.apply(init_func)

    def forward(self, f1_norm, f2_norm):
        """Run the fusion model.

        Args:
            f1_norm: Audio features (treated as the audio stream).
            f2_norm: Video features (treated as the video stream).

        Returns:
            Tuple of the ``vregressor`` and ``aregressor`` outputs, each
            with its last (size-1) axis squeezed away.
        """
        video = F.normalize(f2_norm, dim=-1)
        audio = F.normalize(f1_norm, dim=-1)

        print("LSTM_CAM input video : ", video.shape)
        print("LSTM_CAM input audio : ", audio.shape)

        # Audio is encoded first because it guides the attention over the video.
        audio = self.audio_extract(audio)
        video = self.video_extract(self.video_attn(video, audio))
        print("LSTM_CAM after LSTM video : ", video.shape)
        print("LSTM_CAM after LSTM audio : ", audio.shape)

        video, audio = self.coattn(video, audio)
        print("LSTM_CAM after coattn video : ", video.shape)
        print("LSTM_CAM after coattn audio : ", audio.shape)

        # Append 512 zeros to the last axis of each stream.
        zero_tail = (0, 512)
        video = F.pad(video, zero_tail)
        audio = F.pad(audio, zero_tail)

        print("LSTM_CAM after padding video : ", video.shape)
        print("LSTM_CAM after padding audio : ", audio.shape)

        fused = torch.cat((video, audio), -1)
        print("LSTM_CAM before padding audiovisualfeatures : ", fused.shape)

        fused = self.Joint(fused)
        vouts = self.vregressor(fused)
        aouts = self.aregressor(fused)
        print("LSTM_CAM after Joint audiovisualfeatures : ", fused.shape)
        print("LSTM_CAM vouts : ", vouts.shape)
        print("LSTM_CAM aouts : ", aouts.shape)

        print('--'*25)

        return vouts.squeeze(2), aouts.squeeze(2)

In [15]:
from models.tsav import TwoStreamAuralVisualModel

# Load the pretrained two-stream model and adapt its video stem to 3 input channels.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = 'ABAW2020TNT/model2/TSAV_Sub4_544k.pth.tar'  # path to the model
model = TwoStreamAuralVisualModel(num_channels=4)
# map_location remaps the stored tensors onto ``device``, so a checkpoint
# saved from CUDA tensors still loads on a CPU-only host.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source.
saved_model = torch.load(model_path, map_location=device)
model.load_state_dict(saved_model['state_dict'])

# Replace the first stem convolution (built with num_channels=4 above) by a
# 3-channel one with the same geometry, reusing the weights of the first
# three input channels.
old_first_layer = model.video_model.r2plus1d.stem[0]
new_first_layer = nn.Conv3d(in_channels=3,
                            out_channels=old_first_layer.out_channels,
                            kernel_size=old_first_layer.kernel_size,
                            stride=old_first_layer.stride,
                            padding=old_first_layer.padding,
                            bias=False)

new_first_layer.weight.data = old_first_layer.weight.data[:, 0:3]
model.video_model.r2plus1d.stem[0] = new_first_layer
model = nn.DataParallel(model)
model = model.to(device)


In [16]:
# Shape smoke test: push uninitialised dummy clips through the feature
# extractor and the fusion model. Tensors are allocated on ``device`` (set in
# the model-loading cell) rather than unconditionally on CUDA, so the cell
# follows the same CPU fallback as the model.
num_samples, seq_len = 16, 16
audiodata = torch.empty((num_samples, seq_len, 1, 64, 104), device=device)
visualdata = torch.empty((num_samples, seq_len, 3, 8, 112, 112), device=device)

with torch.cuda.amp.autocast():
    with torch.no_grad():
        visual_feats = torch.empty((num_samples, seq_len, 25088), device=visualdata.device)
        aud_feats = torch.empty((num_samples, seq_len, 512), device=visualdata.device)

        # The extractor handles one sample (a sequence of seq_len clips) per call.
        for i in range(num_samples):
            aud_feat, visualfeat, _ = model(audiodata[i], visualdata[i])
            # Flatten each clip's visual feature map into a single vector.
            visual_feats[i] = visualfeat.view(seq_len, -1)
            aud_feats[i] = aud_feat

# LSTM_CAM moves itself to CUDA when available (see init_weights), which
# matches how ``device`` is chosen.
cam = LSTM_CAM()
result = cam(aud_feats, visual_feats)

initialize network with xavier
LSTM_CAM input video :  torch.Size([16, 16, 25088])
LSTM_CAM input audio :  torch.Size([16, 16, 512])
LSTM inputs :  torch.Size([16, 16, 512])
LSTM inputs :  torch.Size([16, 16, 512])
LSTM_CAM after LSTM video :  torch.Size([16, 16, 512])
LSTM_CAM after LSTM audio :  torch.Size([16, 16, 512])
DenseCoAttn input value1(video) :  torch.Size([16, 16, 512])
DenseCoAttn input value2(audio) :  torch.Size([16, 16, 512])
DenseCoAttn joint representation :  torch.Size([16, 16, 1024])
DenseCoAttn X_v^T :  torch.Size([16, 512, 16])
DenseCoAttn X_a^T :  torch.Size([16, 512, 16])
DenseCoAttn value1 after value_linear :  torch.Size([16, 16, 512])
DenseCoAttn value2 after value_linear :  torch.Size([16, 16, 512])
DenseCoAttn weighted1 :  torch.Size([16, 16, 1024])
DenseCoAttn weighted2 :  torch.Size([16, 16, 1024])
DCNLayer X_att,v :  torch.Size([16, 16, 512])
DCNLayer X_att,a :  torch.Size([16, 16, 512])
LSTM_CAM after coattn video :  torch.Size([16, 16, 512])
LSTM_CAM 