In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LSTM(nn.Module):
    """Bidirectional LSTM encoder with learned initial states and an optional input residual.

    The recurrent output (width ``2 * (dim // 2)``) is optionally concatenated
    with the raw inputs, passed through dropout and a linear projection to
    ``dim`` features, and L2-normalised along the feature axis.

    Args:
        embed_size: Feature size of each input time step.
        dim: Output feature size; each LSTM direction uses ``dim // 2`` units.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability (between LSTM layers and before the projection).
        residual_embeddings: If True, concatenate the inputs to the LSTM output
            before the projection.
    """

    def __init__(self, embed_size, dim, num_layers, dropout, residual_embeddings=True):
        super(LSTM, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        self.rnn_dim = dim // 2
        # The projection must accept exactly what forward() feeds it: the
        # bidirectional output (2 * rnn_dim) plus the inputs only when the
        # residual concatenation is enabled. Sizing it as `dim + embed_size`
        # unconditionally breaks residual_embeddings=False (and odd `dim`).
        rnn_out_dim = 2 * self.rnn_dim
        linear_in = rnn_out_dim + embed_size if residual_embeddings else rnn_out_dim
        self.linear = nn.Linear(linear_in, dim)
        self.rnn = nn.LSTM(embed_size, self.rnn_dim, num_layers=num_layers, dropout=dropout,
                           bidirectional=True, batch_first=True)
        self.residual_embeddings = residual_embeddings
        # Learned initial states: first half of the rows is h0, second half is c0
        # (2 directions * num_layers rows each).
        self.init_hidden = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(2 * 2 * num_layers, self.rnn_dim)))
        self.num_layers = num_layers

    def forward(self, inputs):
        """Encode ``inputs`` of shape (batch, seq_len, embed_size) into (batch, seq_len, dim)."""
        batch = inputs.size(0)
        num_states = 2 * self.num_layers
        # Broadcast the learned initial states over the batch dimension.
        h0 = self.init_hidden[:num_states].unsqueeze(1).expand(num_states, batch, self.rnn_dim).contiguous()
        c0 = self.init_hidden[num_states:].unsqueeze(1).expand(num_states, batch, self.rnn_dim).contiguous()

        print("LSTM inputs : ", inputs.shape)
        outputs, hidden_t = self.rnn(inputs, (h0, c0))

        if self.residual_embeddings:
            outputs = torch.cat([inputs, outputs], dim=-1)
        outputs = self.linear(self.dropout(outputs))

        return F.normalize(outputs, dim=-1)

In [2]:
import math
import torch
import torch.nn as nn
import sys

class DenseCoAttn(nn.Module):
	"""Joint co-attention between two aligned sequences (video and audio in this notebook).

	Both inputs are concatenated into a joint representation that serves as the
	query; each modality supplies its own (transposed) keys and its own values.

	Args:
		dim1: Feature size of the first input.
		dim2: Feature size of the second input.
		dropout: Dropout probability applied to the attention scores.
		seq_len: Sequence length of both inputs. The key projections act along
			the time axis, so their size is tied to it. Defaults to 16, the
			length that used to be hard-coded.
	"""

	def __init__(self, dim1, dim2, dropout, seq_len=16): #dim1, dim2 = 512, 512
		super(DenseCoAttn, self).__init__()
		dim = dim1 + dim2
		self.dropouts = nn.ModuleList([nn.Dropout(p=dropout) for _ in range(2)])
		self.query_linear = nn.Linear(dim, dim)
		# Keys are the transposed inputs, so these layers mix information
		# across time steps rather than across features.
		self.key1_linear = nn.Linear(seq_len, seq_len)
		self.key2_linear = nn.Linear(seq_len, seq_len)
		self.value1_linear = nn.Linear(dim1, dim1)
		self.value2_linear = nn.Linear(dim2, dim2)
		self.relu = nn.ReLU()

	def forward(self, value1, value2):
		"""Attend both inputs with the joint representation.

		Args:
			value1: Tensor of shape (batch, seq_len, dim1).
			value2: Tensor of shape (batch, seq_len, dim2).

		Returns:
			Two tensors of shape (batch, seq_len, dim1 + dim2): the attended
			first and second modality.
		"""
		print("DenseCoAttn input value1(video) : ", value1.shape) # e.g. 16, 16, 512
		print("DenseCoAttn input value2(audio) : ", value2.shape)
		joint = torch.cat((value1, value2), dim=-1)
		# Joint query: W * [value1 ; value2]
		joint = self.query_linear(joint)
		print("DenseCoAttn joint representation : ", joint.shape)
		key1 = self.key1_linear(value1.transpose(1, 2)) # X_v^T
		key2 = self.key2_linear(value2.transpose(1, 2)) # X_a^T
		print("DenseCoAttn X_v^T : ", key1.shape) # e.g. 16, 512, 16
		print("DenseCoAttn X_a^T : ", key2.shape)

		# Per-modality feature projections applied before the attention weighting.
		value1 = self.value1_linear(value1)
		value2 = self.value2_linear(value2)
		print("DenseCoAttn value1 after value_linear : ", value1.shape)
		print("DenseCoAttn value2 after value_linear : ", value2.shape)

		weighted1, attn1 = self.qkv_attention(joint, key1, value1, dropout=self.dropouts[0])
		weighted2, attn2 = self.qkv_attention(joint, key2, value2, dropout=self.dropouts[1])
		print("DenseCoAttn weighted1 : ", weighted1.shape)
		print("DenseCoAttn weighted2 : ", weighted2.shape)

		return weighted1, weighted2

	def qkv_attention(self, query, key, value, dropout=None):
		"""Return ``(relu(tanh(value @ scores)), scores)`` with ``scores = tanh(key @ query / sqrt(d_k))``.

		``d_k`` is the last dimension of ``query``; ``dropout`` (if given) is
		applied to the scores.
		"""
		d_k = query.size(-1)
		scores = torch.bmm(key, query) / math.sqrt(d_k)
		scores = torch.tanh(scores) # C_v, C_a
		# Explicit None check: do not rely on the truthiness of an nn.Module.
		if dropout is not None:
			scores = dropout(scores)

		weighted = torch.tanh(torch.bmm(value, scores))
		return self.relu(weighted), scores # self.relu(weighted) == H_v, H_a


In [3]:
import torch.nn as nn

class NormalSubLayer(nn.Module):
    """One co-attention step followed by a residual update of both streams.

    The attended features produced by ``DenseCoAttn`` (width ``dim1 + dim2``)
    are projected back to each stream's own width and added to that stream.
    """

    def __init__(self, dim1, dim2, dropout): # dim1, dim2 = 512, 512
        super().__init__()
        joint_dim = dim1 + dim2
        self.dense_coattn = DenseCoAttn(dim1, dim2, dropout)

        def projection(out_dim):
            # Linear -> ReLU -> Dropout head mapping the joint width to one stream.
            return nn.Sequential(
                nn.Linear(joint_dim, out_dim),
                nn.ReLU(inplace=False),
                nn.Dropout(p=dropout),
            )

        # Index 0 serves the first stream, index 1 the second.
        self.linears = nn.ModuleList([projection(dim1), projection(dim2)])

    def forward(self, data1, data2):
        """Return the residually updated ``(data1, data2)`` pair."""
        attended1, attended2 = self.dense_coattn(data1, data2)
        project1, project2 = self.linears

        updated1 = data1 + project1(attended1) # X_att,v^t
        updated2 = data2 + project2(attended2) # X_att,a^t

        print("DCNLayer X_att,v : " , updated1.shape)
        print("DCNLayer X_att,a : " , updated2.shape)

        return updated1, updated2


class DCNLayer(nn.Module):
    """Stack of ``num_seq`` co-attention sub-layers applied one after another."""

    def __init__(self, dim1, dim2, num_seq, dropout): # dim1, dim2 = 512, 512
        super().__init__()
        # One sub-layer per iteration t of the co-attention refinement.
        stacked = []
        for _ in range(num_seq):
            stacked.append(NormalSubLayer(dim1, dim2, dropout))
        self.dcn_layers = nn.ModuleList(stacked)

    def forward(self, data1, data2):
        """Feed the pair through every sub-layer and return the final pair."""
        pair = (data1, data2)
        for sub_layer in self.dcn_layers:
            pair = sub_layer(*pair)
        return pair

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys

class BottomUpExtract(nn.Module):
	"""Thin wrapper that delegates to a ``PositionAttn`` module."""

	def __init__(self, emed_dim, dim):
		super().__init__()
		self.attn = PositionAttn(emed_dim, dim)

	def forward(self, video, audio):
		"""Return the output of the wrapped attention module for ``(video, audio)``."""
		return self.attn(video, audio)

# audio-guided attention
class PositionAttn(nn.Module):
	"""Audio-guided attention over the spatial regions of each video time step.

	Args:
		embed_dim: Feature size of each audio time step.
		dim: Hidden size of the attention and size of the returned features.
		video_dim: Channel size of one video region. Defaults to 512, the
			value that used to be hard-coded.
		num_regions: Number of regions per video time step. Defaults to 49,
			the value that used to be hard-coded.
	"""

	def __init__(self, embed_dim, dim, video_dim=512, num_regions=49):
		super(PositionAttn, self).__init__()
		self.video_dim = video_dim
		self.affine_audio = nn.Linear(embed_dim, dim)
		self.affine_video = nn.Linear(video_dim, dim)
		self.affine_v = nn.Linear(dim, num_regions, bias=False)
		self.affine_g = nn.Linear(dim, num_regions, bias=False)
		self.affine_h = nn.Linear(num_regions, 1, bias=False)
		self.affine_feat = nn.Linear(video_dim, dim)
		self.relu = nn.ReLU()

	def forward(self, video, audio):
		"""Pool the regions of every video step with audio-conditioned weights.

		Args:
			video: Tensor of shape (batch, seq_len, num_regions * video_dim).
			audio: Tensor of shape (batch, seq_len, embed_dim).

		Returns:
			Tensor of shape (batch, seq_len, dim).
		"""
		# Fold batch and time together; one row of regions per time step.
		# reshape (instead of view) also accepts non-contiguous inputs.
		v_t = video.reshape(video.size(0) * video.size(1), -1, self.video_dim)
		V = v_t  # unprojected region features, pooled further below

		# Audio-guided visual attention
		v_t = self.relu(self.affine_video(v_t))
		a_t = audio.reshape(-1, audio.size(-1))
		a_t = self.relu(self.affine_audio(a_t))

		# The audio term is broadcast over the last axis of the video term.
		content_v = self.affine_v(v_t) + self.affine_g(a_t).unsqueeze(2)

		z_t = self.affine_h(torch.tanh(content_v)).squeeze(2)

		alpha_t = F.softmax(z_t, dim=-1).view(z_t.size(0), -1, z_t.size(1))  # attention map

		# Weighted sum of the raw region features for every time step.
		c_t = torch.bmm(alpha_t, V).view(-1, self.video_dim)
		video_t = c_t.view(video.size(0), -1, self.video_dim)

		video_t = self.affine_feat(video_t)

		return video_t


In [5]:
import torch
import torch.nn as nn

class Chomp1d(nn.Module):
    """Drop the last ``chomp_size`` steps of the final axis of a 3-D tensor.

    Used after a padded ``Conv1d`` to remove the extra padding it introduced.
    """

    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        """Return ``x`` without its trailing ``chomp_size`` positions."""
        # `x[:, :, :-0]` is an EMPTY slice, so a zero chomp (e.g. kernel_size=1,
        # which yields padding 0) must be treated as a no-op.
        if self.chomp_size == 0:
            return x
        return x[:, :, :-self.chomp_size]

class TemporalBlock(nn.Module):
    """Two dilated Conv1d -> Chomp1d -> ReLU -> Dropout stages with a residual connection.

    When the channel count changes, the shortcut is a 1x1 convolution;
    otherwise the input is added unchanged. A final ReLU follows the sum.
    """

    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
        super().__init__()
        # Both convolutions share the same geometry.
        conv_kwargs = dict(stride=stride, padding=padding, dilation=dilation)

        self.conv1 = nn.Conv1d(n_inputs, n_outputs, kernel_size, **conv_kwargs)
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = nn.Conv1d(n_outputs, n_outputs, kernel_size, **conv_kwargs)
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        stages = [
            self.conv1, self.chomp1, self.relu1, self.dropout1,
            self.conv2, self.chomp2, self.relu2, self.dropout2,
        ]
        self.net = nn.Sequential(*stages)

        if n_inputs == n_outputs:
            self.downsample = None
        else:
            # 1x1 convolution so the shortcut matches the branch's channel count.
            self.downsample = nn.Conv1d(n_inputs, n_outputs, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(net(x) + shortcut(x))``."""
        branch = self.net(x)
        shortcut = self.downsample(x) if self.downsample is not None else x
        return self.relu(branch + shortcut)

class TemporalConvNet(nn.Module):
    """Sequence of ``TemporalBlock`` levels whose dilation doubles at every level.

    Args:
        num_inputs: Channel count of the input.
        num_channels: Output channel count of each level, in order.
        kernel_size: Convolution kernel size shared by all levels.
        dropout: Dropout probability shared by all levels.
    """

    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super().__init__()
        blocks = []
        prev_channels = num_inputs
        for level, width in enumerate(num_channels):
            dilation = 2 ** level
            blocks.append(
                TemporalBlock(prev_channels, width, kernel_size, stride=1, dilation=dilation,
                              padding=(kernel_size - 1) * dilation, dropout=dropout)
            )
            # The next level consumes what this level produced.
            prev_channels = width
        self.network = nn.Sequential(*blocks)

    def forward(self, x):
        """Apply all levels in order to ``x``."""
        return self.network(x)


In [6]:
from __future__ import absolute_import
from __future__ import division

from torch.nn import init
import torch
from torch import nn
from torch.nn import functional as F

class TLAB(nn.Module):
    """Fuse an LSTM branch and a TCN branch with dot-product attention.

    Queries come from the LSTM features, keys and values from the TCN
    features; the attended result is added back onto the LSTM features.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.lstm = LSTM(input_dim, hidden_dim, num_layers=2, dropout=0.1, residual_embeddings=True)
        self.tcn = TemporalConvNet(
            num_inputs=input_dim,
            num_channels=[hidden_dim, hidden_dim],
            kernel_size=3,
            dropout=0.1,
        )
        # Attention projections
        self.query_fc = nn.Linear(hidden_dim, hidden_dim)  # Q, from the LSTM branch
        self.key_fc = nn.Linear(hidden_dim, hidden_dim)    # K, from the TCN branch
        self.value_fc = nn.Linear(hidden_dim, hidden_dim)  # V, from the TCN branch
        self.attention_softmax = nn.Softmax(dim=-1)        # normalises each query's scores

    def forward(self, x):
        """Map ``x`` of shape (batch, seq_len, input_dim) to (batch, seq_len, hidden_dim)."""
        # LSTM branch consumes (batch, seq_len, features) directly.
        global_feat = self.lstm(x)
        print("lstm_feat Output : ", global_feat.shape)
        # The TCN works channel-first, so swap the last two axes around it.
        local_feat = self.tcn(x.transpose(1, 2)).transpose(1, 2)
        print("tcn_feat Output : ", local_feat.shape)

        queries = self.query_fc(global_feat)
        keys = self.key_fc(local_feat)
        values = self.value_fc(local_feat)

        # (batch, seq_len, seq_len) attention map; scores are not scaled.
        weights = self.attention_softmax(torch.matmul(queries, keys.transpose(-1, -2)))
        attended = torch.matmul(weights, values)

        # Residual combination with the LSTM features.
        fused = global_feat + attended
        print("TLAB output shape : ", fused.shape)

        return fused


class LSTM_CAM(nn.Module):
    """Audio-visual fusion model producing two per-time-step regression outputs.

    Pipeline (see ``forward``): TLAB on audio -> audio-guided pooling of the
    video features -> TLAB on video -> dense co-attention -> joint LSTM ->
    two regression heads (``vregressor`` and ``aregressor``).
    """

    def __init__(self):
        super(LSTM_CAM, self).__init__()
        self.coattn = DCNLayer(512, 512, 1, 0.6)
        self.avga = BottomUpExtract(512, 512)

        # Audio and Video TLABs
        self.audio_tlab = TLAB(512, 512)
        self.video_tlab = TLAB(512, 512)

        self.vregressor = nn.Sequential(
            nn.Linear(512, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.6),
            nn.Linear(128, 1),
        )

        # Consumes the concatenated (video, audio) features: 512 + 512 = 1024.
        self.Joint = LSTM(1024, 512, 2, dropout=0, residual_embeddings=True)

        self.aregressor = nn.Sequential(
            nn.Linear(512, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.6),
            nn.Linear(128, 1),
        )

        self.init_weights()

    def init_weights(self, init_type='xavier', init_gain=1):
        """Re-initialise every Conv*/Linear* sub-module and zero its bias.

        Also moves the module to CUDA when a GPU is available.

        Args:
            init_type: One of 'normal', 'xavier', 'kaiming', 'orthogonal'.
            init_gain: Std for 'normal'; gain for 'xavier' and 'orthogonal'.

        Raises:
            NotImplementedError: If ``init_type`` is not one of the above.
        """
        if torch.cuda.is_available():
            self.cuda()

        def init_func(m):  # applied to every sub-module by nn.Module.apply
            classname = m.__class__.__name__
            if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
                if init_type == 'normal':
                    # Was init.uniform_(..., 0.0, init_gain), which contradicts
                    # the option's name; draw from N(0, init_gain^2) instead.
                    init.normal_(m.weight.data, 0.0, init_gain)
                elif init_type == 'xavier':
                    init.xavier_uniform_(m.weight.data, gain=init_gain)
                elif init_type == 'kaiming':
                    init.kaiming_uniform_(m.weight.data, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    init.orthogonal_(m.weight.data, gain=init_gain)
                else:
                    raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
                if hasattr(m, 'bias') and m.bias is not None:
                    init.constant_(m.bias.data, 0.0)

        print('initialize network with %s' % init_type)
        self.apply(init_func)

    def forward(self, f1_norm, f2_norm):
        """Predict the two outputs for every time step.

        Args:
            f1_norm: Audio features, (batch, seq_len, 512).
            f2_norm: Video features; the last axis is pooled over regions by
                ``self.avga`` before the video TLAB.

        Returns:
            ``(vouts, aouts)``: outputs of ``vregressor`` and ``aregressor``
            with the trailing singleton axis removed.
        """
        video = F.normalize(f2_norm, dim=-1)
        audio = F.normalize(f1_norm, dim=-1)

        audio = self.audio_tlab(audio)
        print("audio_tlab feat : ", audio.shape)

        # The pooling of the video regions is guided by the TLAB-encoded audio.
        video = self.avga(video, audio)

        video = self.video_tlab(video)
        print("video_tlab feat : ", video.shape)

        video, audio = self.coattn(video, audio)

        audiovisualfeatures = torch.cat((video, audio), -1)

        audiovisualfeatures = self.Joint(audiovisualfeatures)
        vouts = self.vregressor(audiovisualfeatures)
        aouts = self.aregressor(audiovisualfeatures)

        return vouts.squeeze(2), aouts.squeeze(2)

In [7]:
from models.tsav import TwoStreamAuralVisualModel

# Load the pretrained two-stream audio-visual backbone.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = 'ABAW2020TNT/model2/TSAV_Sub4_544k.pth.tar' # path to the model
model = TwoStreamAuralVisualModel(num_channels=4)
# map_location="cpu": the model still lives on the CPU at this point, and a
# checkpoint saved from a GPU cannot be deserialised on a CPU-only host otherwise.
saved_model = torch.load(model_path, map_location="cpu")
model.load_state_dict(saved_model['state_dict'])

# Replace the 4-channel video stem with a 3-channel one that has the same geometry.
new_first_layer = nn.Conv3d(in_channels=3,
					out_channels=model.video_model.r2plus1d.stem[0].out_channels,
					kernel_size=model.video_model.r2plus1d.stem[0].kernel_size,
					stride=model.video_model.r2plus1d.stem[0].stride,
					padding=model.video_model.r2plus1d.stem[0].padding,
					bias=False)

# Reuse the pretrained filters of the first three input channels.
new_first_layer.weight.data = model.video_model.r2plus1d.stem[0].weight.data[:, 0:3]
model.video_model.r2plus1d.stem[0] = new_first_layer
model = nn.DataParallel(model)
model = model.to(device)


# Train

In [8]:
from __future__ import print_function
import os
import time

import torch
import torch.nn.parallel
import torch.optim
from tqdm import tqdm

import utils.utils as utils
from EvaluationMetrics.cccmetric import ccc

import logging

# Enables autograd anomaly detection globally for the whole kernel session.
# NOTE(review): this is a debugging aid and slows every backward pass — confirm it is still wanted.
torch.autograd.set_detect_anomaly(True)

# Step-decay schedule read by train(): once `epoch` exceeds
# `learning_rate_decay_start`, the learning rate is multiplied by
# `learning_rate_decay_rate` every `learning_rate_decay_every` epochs.
# The trailing numbers are presumably earlier settings — TODO confirm.
learning_rate_decay_start = 5  # 50
learning_rate_decay_every = 2 # 5
learning_rate_decay_rate = 0.8 # 0.9
total_epoch = 30  # upper bound of the epoch loop
lr = 0.0001  # base learning rate handed to train()
# NOTE(review): the only uses of `scaler` inside train() are commented out.
scaler = torch.cuda.amp.GradScaler()

def train(train_loader, model, criterion, optimizer, scheduler, epoch, lr, cam, time_chk_path):
	"""Train the fusion model ``cam`` for one epoch on features from the frozen ``model``.

	Reads the module-level schedule constants ``learning_rate_decay_start``,
	``learning_rate_decay_every`` and ``learning_rate_decay_rate``.

	Args:
		train_loader: Iterable of ``(visualdata, audiodata, labels_V, labels_A)`` batches.
		model: Feature extractor; kept in eval mode and run under ``no_grad``.
		criterion: Loss applied separately to the two outputs; the two losses are summed.
		optimizer: Optimizer of the ``cam`` parameters.
		scheduler: Stepped once per epoch with the mean batch loss.
		epoch: Current epoch index (drives the step decay of the learning rate).
		lr: Base learning rate before decay.
		cam: Fusion model being trained.
		time_chk_path: Directory receiving ``time_chk.txt`` timing logs, or a
			falsy value to disable timing logs.

	Returns:
		``(train_vacc, train_aacc, final_loss)``: the ``ccc`` score of each
		output over the processed batches (0 when fewer than two targets were
		collected) and the detached loss of the last batch (``None`` when the
		loader yielded no batch).
	"""
	print('\nEpoch: %d' % epoch)

	# Only the fusion model is trained; the backbone just provides features.
	model.eval()
	cam.train()

	epoch_loss = 0
	vout = list()
	vtar = list()

	aout = list()
	atar = list()

	# Step decay: after `learning_rate_decay_start` epochs, shrink the rate
	# every `learning_rate_decay_every` epochs.
	if epoch > learning_rate_decay_start and learning_rate_decay_start >= 0:
		frac = (epoch - learning_rate_decay_start) // learning_rate_decay_every
		decay_factor = learning_rate_decay_rate ** frac
		current_lr = lr * decay_factor
	else:
		current_lr = lr
	# Always set the rate explicitly (also needed when no checkpoint restored it).
	utils.set_lr(optimizer, current_lr)
	print('learning_rate: %s' % str(current_lr))
	logging.info("Learning rate")
	logging.info(current_lr)

	n = 0
	final_loss = None
	time_chk_file = os.path.join(time_chk_path, "time_chk.txt") if time_chk_path else None

	batch_iter = tqdm(enumerate(train_loader), total=len(train_loader), position=0, leave=True)
	for batch_idx, (visualdata, audiodata, labels_V, labels_A) in batch_iter:

		print("====" * 20)
		print("Batch Index : ", batch_idx)

		print("====" * 20)
		print("visualdata : ", visualdata.shape)
		print("audiodata : ", audiodata.shape)
		print("labels_V : ", labels_V.shape)
		print("labels_A : ", labels_A.shape)

		print("====" * 20)

		optimizer.zero_grad(set_to_none=True)
		audiodata = audiodata.cuda()
		visualdata = visualdata.cuda()

		# NOTE(review): autocast is active but the GradScaler path is disabled,
		# so half-precision gradients are not loss-scaled — confirm this is intended.
		with torch.cuda.amp.autocast():
			with torch.no_grad():
				b, seq_t, c, subseq_t, h, w = visualdata.size()
				visual_feats = torch.empty((b, seq_t, 25088), dtype=visualdata.dtype, device=visualdata.device)
				aud_feats = torch.empty((b, seq_t, 512), dtype=visualdata.dtype, device=visualdata.device)

				# The backbone is run one sample (i.e. one sequence) at a time.
				for i in range(visualdata.shape[0]):
					st1 = time.time()
					aud_feat, visualfeat, _ = model(audiodata[i,:,:,:], visualdata[i, :, :, :,:,:])
					ed1 = time.time()

					pre_trained_model_time = ed1 - st1
					if time_chk_file:
						with open(time_chk_file, 'a') as f:
							f.write(f"Time pre_trained_model: {pre_trained_model_time}\n")
					# Flatten everything after the time axis into one feature vector.
					visual_feats[i,:,:] = visualfeat.view(seq_t, -1)
					aud_feats[i,:,:] = aud_feat

			st2 = time.time()
			audiovisual_vouts, audiovisual_aouts = cam(aud_feats, visual_feats)
			print("audiovisual_vouts : " , audiovisual_vouts.shape)
			print("audiovisual_aouts : " , audiovisual_aouts.shape)
			ed2 = time.time()

			time_cam_model = ed2 - st2
			if time_chk_file:
				# The `with` block closes the file; no explicit close() is needed.
				with open(time_chk_file, 'a') as f:
					f.write(f"Time cam model: {time_cam_model}\n")
					f.write(f"Epoch: {epoch}\n")
					f.write(f"batch_idx: {batch_idx}\n")
					f.write("----"*20)
					f.write("\n")

			# Flatten (batch, seq) into a single row for the loss.
			voutputs = audiovisual_vouts.view(-1, audiovisual_vouts.shape[0]*audiovisual_vouts.shape[1])
			aoutputs = audiovisual_aouts.view(-1, audiovisual_aouts.shape[0]*audiovisual_aouts.shape[1])
			vtargets = labels_V.view(-1, labels_V.shape[0]*labels_V.shape[1]).cuda()
			atargets = labels_A.view(-1, labels_A.shape[0]*labels_A.shape[1]).cuda()

			v_loss = criterion(voutputs, vtargets)
			a_loss = criterion(aoutputs, atargets)

			final_loss = v_loss + a_loss

			epoch_loss += final_loss.item()

		with torch.autograd.set_detect_anomaly(True):
			# The graph is rebuilt for every batch, so it need not be retained.
			final_loss.backward()
			optimizer.step()
		n = n + 1

		# extend() avoids re-copying the accumulated lists on every batch.
		vout.extend(voutputs.squeeze(0).detach().cpu().tolist())
		vtar.extend(vtargets.squeeze(0).detach().cpu().tolist())

		aout.extend(aoutputs.squeeze(0).detach().cpu().tolist())
		atar.extend(atargets.squeeze(0).detach().cpu().tolist())

		# NOTE(review): debugging leftover — only the FIRST batch of every epoch
		# is processed. Remove this `break` for a real training run.
		break

	# Guard against an empty loader (previously a ZeroDivisionError).
	if n > 0:
		scheduler.step(epoch_loss / n)

	if (len(vtar) > 1):
		train_vacc = ccc(vout, vtar)
		train_aacc = ccc(aout, atar)
	else:
		# Previously only an unrelated `train_acc` was set here, so the prints
		# and the return below raised NameError.
		train_vacc = 0
		train_aacc = 0
	print("Train Accuracy")
	print(train_vacc)
	print(train_aacc)

	# Detach so the caller does not keep the last batch's autograd graph alive.
	last_loss = final_loss.detach() if final_loss is not None else None
	return train_vacc, train_aacc, last_loss

# Main

In [9]:
import os
import time
import random
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import gc
import logging
import numpy as np
from models.tsav import TwoStreamAuralVisualModel
from datasets.dataset_new import ImageList
from torch.optim.lr_scheduler import ReduceLROnPlateau
from losses.loss import CCCLoss
from datetime import datetime, timedelta
from torch import nn
import json
from warnings import filterwarnings
filterwarnings("ignore")

# Enumerate GPUs in PCI bus order.
# NOTE(review): this only takes effect if it runs before CUDA is initialised — confirm cell order.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
is_time_chk = False

best_Val_acc = 0  # best PrivateTest accuracy
best_Val_acc_epoch = 0

# Per-epoch metric histories.
TrainingAccuracy_V = []
TrainingAccuracy_A = []
ValidationAccuracy_V = []
ValidationAccuracy_A = []

Logfile_name = "LogFiles/" + "log_file.log"
# logging's FileHandler does not create missing directories; without this the
# very first run fails with FileNotFoundError.
os.makedirs(os.path.dirname(Logfile_name), exist_ok=True)
logging.basicConfig(filename=Logfile_name, level=logging.INFO)

# Seed every random number generator used in this notebook.
SEED = int(0)

random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(SEED)
torch.cuda.manual_seed_all(SEED)


class TrainPadSequence:
	"""Collate function for training batches.

	Each sample is ``(visual, audio, label_V, label_A)``. Audio tensors may
	differ in the size of their last axis; they are zero-padded on the left of
	that axis to the longest sample in the batch. Everything else is stacked.
	"""

	@staticmethod
	def _left_pad_audio(aud_sequences):
		"""Stack 4-D audio tensors, left-padding the last axis with zeros."""
		max_spec_dim = max(aud.shape[3] for aud in aud_sequences)
		# Leading dimensions come from the data instead of a hard-coded (16, 1, 64).
		lead_shape = tuple(aud_sequences[0].shape[:3])
		audio_features = torch.zeros(len(aud_sequences), *lead_shape, max_spec_dim)
		for batch_idx, spectrogram in enumerate(aud_sequences):
			# Right-align every sample. The previous code compared shape[2]
			# (not the padded axis) with the maximum length, so short samples
			# could reach the full-width assignment and fail.
			offset = max_spec_dim - spectrogram.shape[3]
			audio_features[batch_idx, :, :, :, offset:] = spectrogram
		return audio_features

	def __call__(self, sorted_batch):
		"""Return ``(visual_sequences, audio_features, labelsV, labelsA)`` for one batch."""
		sequences = [x[0] for x in sorted_batch]
		aud_sequences = [x[1] for x in sorted_batch]
		audio_features = self._left_pad_audio(aud_sequences)

		labelV = [x[2] for x in sorted_batch]
		labelA = [x[3] for x in sorted_batch]
		visual_sequences = torch.stack(sequences)
		labelsV = torch.stack(labelV)
		labelsA = torch.stack(labelA)

		return visual_sequences, audio_features, labelsV, labelsA


class ValPadSequence:
	"""Collate function for validation batches.

	Each sample is ``(visual, audio, frameids, v_id, v_length, label_V, label_A)``.
	Audio tensors may differ in the size of their last axis; they are
	zero-padded on the left of that axis to the longest sample in the batch.
	Tensors are stacked; the id/length fields are returned as plain lists.
	"""

	@staticmethod
	def _left_pad_audio(aud_sequences):
		"""Stack 4-D audio tensors, left-padding the last axis with zeros."""
		max_spec_dim = max(aud.shape[3] for aud in aud_sequences)
		# Leading dimensions come from the data instead of a hard-coded (16, 1, 64).
		lead_shape = tuple(aud_sequences[0].shape[:3])
		audio_features = torch.zeros(len(aud_sequences), *lead_shape, max_spec_dim)
		for batch_idx, spectrogram in enumerate(aud_sequences):
			# Right-align every sample. The previous code compared shape[2]
			# (not the padded axis) with the maximum length, so short samples
			# could reach the full-width assignment and fail.
			offset = max_spec_dim - spectrogram.shape[3]
			audio_features[batch_idx, :, :, :, offset:] = spectrogram
		return audio_features

	def __call__(self, sorted_batch):
		"""Return ``(visual_sequences, audio_features, frameids, v_ids, v_lengths, labelsV, labelsA)``."""
		sequences = [x[0] for x in sorted_batch]
		aud_sequences = [x[1] for x in sorted_batch]
		audio_features = self._left_pad_audio(aud_sequences)

		frameids = [x[2] for x in sorted_batch]
		v_ids = [x[3] for x in sorted_batch]
		v_lengths = [x[4] for x in sorted_batch]
		labelV = [x[5] for x in sorted_batch]
		labelA = [x[6] for x in sorted_batch]

		visual_sequences = torch.stack(sequences)
		labelsV = torch.stack(labelV)
		labelsA = torch.stack(labelA)
		return visual_sequences, audio_features, frameids, v_ids, v_lengths, labelsV, labelsA


# Output directories; exist_ok makes a separate existence check unnecessary.
os.makedirs("SavedWeights", exist_ok=True)

weight_save_path = "SavedWeights"

result_save_path ="save"
os.makedirs(result_save_path, exist_ok=True)

### Loading audiovisual model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_path = 'ABAW2020TNT/model2/TSAV_Sub4_544k.pth.tar' # path to the model
model = TwoStreamAuralVisualModel(num_channels=4)
# map_location="cpu": the model still lives on the CPU at this point, and a
# checkpoint saved from a GPU cannot be deserialised on a CPU-only host otherwise.
saved_model = torch.load(model_path, map_location="cpu")
model.load_state_dict(saved_model['state_dict'])

# Replace the 4-channel video stem with a 3-channel one that has the same geometry.
new_first_layer = nn.Conv3d(in_channels=3,
					out_channels=model.video_model.r2plus1d.stem[0].out_channels,
					kernel_size=model.video_model.r2plus1d.stem[0].kernel_size,
					stride=model.video_model.r2plus1d.stem[0].stride,
					padding=model.video_model.r2plus1d.stem[0].padding,
					bias=False)

# Reuse the pretrained filters of the first three input channels.
new_first_layer.weight.data = model.video_model.r2plus1d.stem[0].weight.data[:, 0:3]
model.video_model.r2plus1d.stem[0] = new_first_layer
model = nn.DataParallel(model)
model = model.to(device)

### Freezing the model
for p in model.parameters():
	p.requires_grad = False
# eval() covers the wrapper and all of its children, not only the children.
model.eval()

fusion_model = LSTM_CAM()

print_model_name = fusion_model.__class__.__name__
print("Fusion Model : ", print_model_name)

fusion_model = fusion_model.to(device=device)

print('==> Preparing data..')

def matching_files(root_path, anno_path):
	"""Return the entries of ``root_path`` that have an annotation in ``anno_path``.

	An entry matches when its name equals the part of some annotation file
	name before the first ``"."``. The result keeps ``os.listdir`` order.
	"""
	# A set gives O(1) membership tests, and a single directory listing avoids
	# deleting by index from a second, separately obtained listing.
	annotated = {f.split(".")[0] for f in os.listdir(anno_path)}
	return [f for f in os.listdir(root_path) if f in annotated]

def train_val_test_split(root_path, anno_path, seed=0):
	"""Randomly split the annotated recordings into 60% / 20% / 20% file lists.

	Seeds the global ``random`` module with ``seed``, so the split is
	reproducible. Every returned name has ``".csv"`` appended. Split sizes are
	rounded down, so a few recordings may end up in no split.

	Returns:
		``(train_set, valid_set, test_set)`` lists of file names.
	"""
	random.seed(seed)
	trial_data = matching_files(root_path, anno_path)

	fname_dict = {i:f for i,f in enumerate(trial_data)}
	length = len(fname_dict)

	print("full trial length: ", len(fname_dict))

	def draw(count):
		# random.sample requires a sequence: passing dict.keys() raises
		# TypeError on Python >= 3.11. list() keeps insertion order, so the
		# draw is the same as on older interpreters.
		picked = random.sample(list(fname_dict), count)
		# pop() removes the drawn entries so later splits cannot reuse them.
		return [fname_dict.pop(i) + ".csv" for i in picked]

	train_set = draw(int(length*0.6))
	valid_set = draw(int(length*0.2))
	test_set = draw(int(length*0.2))

	return train_set, valid_set, test_set
    
# Load the run configuration from a JSON file in the current working directory.
with open('config_file.json', 'r') as f:
	configuration = json.load(f)

# Convenience aliases for the dataset locations stored in the configuration.
dataset_rootpath = configuration['dataset_rootpath']
dataset_wavspath = configuration['dataset_wavspath']
dataset_labelpath = configuration['labelpath']

def load_partition_set(partition_path, seed):
	"""Read the train/validation/test file lists stored for ``seed``.

	The JSON file at ``partition_path`` is expected to hold, under the key
	``seed_<seed>``, the lists ``Train_Set``, ``Validation_Set`` and
	``Test_Set``. ``".csv"`` is appended to every name.

	Returns:
		``(train_files, valid_files, test_files)`` lists of file names.
	"""
	import json

	with open(partition_path, 'r') as partition_file:
		partitions = json.load(partition_file)

	seed_entry = partitions[f'seed_{seed}']

	def with_csv_suffix(split_name):
		return [stem + ".csv" for stem in seed_entry[split_name]]

	train_files = with_csv_suffix('Train_Set')
	valid_files = with_csv_suffix('Validation_Set')
	test_files = with_csv_suffix('Test_Set')

	return train_files, valid_files, test_files

# Pre-computed train/validation/test partition, keyed by seed.
partition_path = "../data/Affwild2/seed_data.json"
 
train_set, valid_set, test_set = load_partition_set(partition_path, SEED)

# Start timestamp formatted as month-day_hour-minute.
init_time = datetime.now()
init_time = init_time.strftime('%m%d_%H%M')

root_time_chk_dir = "time_chk"

# None disables the timing log that train() would otherwise write.
time_chk_path = None

print("Train Data")
# Training dataset; sequence length, stride, dilation and sub-sequence length
# all come from configuration['train_params'].
traindataset = ImageList(root=configuration['dataset_rootpath'], fileList=train_set, labelPath=dataset_labelpath,
                        audList=configuration['dataset_wavspath'], length=configuration['train_params']['seq_length'],
                        flag='train', stride=configuration['train_params']['stride'], dilation = configuration['train_params']['dilation'],
                        subseq_length = configuration['train_params']['subseq_length'], time_chk_path=time_chk_path)
# TrainPadSequence pads the variable-length audio tensors of each batch.
trainloader = torch.utils.data.DataLoader(
                traindataset, collate_fn=TrainPadSequence(),
                **configuration['train_params']['loader_params'])
print("Number of Train samples:" + str(len(traindataset)))

# The same criterion instance is used for both outputs inside train().
criterion = CCCLoss(digitize_num=1).cuda()
# Only the fusion model's parameters are optimised.
optimizer = torch.optim.Adam(fusion_model.parameters(),# filter(lambda p: p.requires_grad, multimedia_model.parameters()),
								configuration['model_params']['lr'])

# Halves the learning rate when the value passed to scheduler.step() stops decreasing.
scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=10, verbose=True)

cnt = 0
fusion_model_name = 'gat'
# NOTE(review): the epoch loop passes this `lr` to train(), which calls
# utils.set_lr on the optimizer, so it presumably overrides the configured
# optimizer rate above — confirm which value is intended.
lr = 0.0001

initialize network with xavier
Fusion Model :  LSTM_CAM
==> Preparing data..
Train Data
Number of Sequences: 247
Number of Train samples:68709


In [10]:
# Epoch loop: logs the epoch index and trains the fusion model for one epoch.
for epoch in range(0, total_epoch):
	epoch_tic = time.time()
	logging.info("Epoch")
	logging.info(epoch)

	# train for one epoch
	Training_vacc, Training_aacc, Training_loss = train(trainloader, model, criterion, optimizer, scheduler, epoch, lr, fusion_model, time_chk_path=time_chk_path)
	# NOTE(review): debugging leftover — the loop stops after the first epoch.
	# Remove this `break` to train for `total_epoch` epochs.
	break


Epoch: 0
learning_rate: 0.0001


  0%|          | 0/4295 [00:00<?, ?it/s]

Batch Index :  0
visualdata :  torch.Size([16, 16, 3, 8, 112, 112])
audiodata :  torch.Size([16, 16, 1, 64, 104])
labels_V :  torch.Size([16, 16])
labels_A :  torch.Size([16, 16])
LSTM inputs :  torch.Size([16, 16, 512])
lstm_feat Output :  torch.Size([16, 16, 512])
tcn_feat Output :  torch.Size([16, 16, 512])
TLAB output shape :  torch.Size([16, 16, 512])
audio_tlab feat :  torch.Size([16, 16, 512])
LSTM inputs :  torch.Size([16, 16, 512])
lstm_feat Output :  torch.Size([16, 16, 512])
tcn_feat Output :  torch.Size([16, 16, 512])
TLAB output shape :  torch.Size([16, 16, 512])
video_tlab feat :  torch.Size([16, 16, 512])
DenseCoAttn input value1(video) :  torch.Size([16, 16, 512])
DenseCoAttn input value2(audio) :  torch.Size([16, 16, 512])
DenseCoAttn joint representation :  torch.Size([16, 16, 1024])
DenseCoAttn X_v^T :  torch.Size([16, 512, 16])
DenseCoAttn X_a^T :  torch.Size([16, 512, 16])
DenseCoAttn value1 after value_linear :  torch.Size([16, 16, 512])
DenseCoAttn value2 after v

  0%|          | 0/4295 [00:10<?, ?it/s]

Train Accuracy
-0.018203811194385407
0.09929763250821377





: 

: 