In [None]:
# Mount Google Drive so the project folder under MyDrive is reachable from the Colab VM.
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:
# Work from the project folder on Drive and list its contents.
%cd /content/drive/MyDrive/carknows
%ls

/content/drive/MyDrive/carknows
Cardataset_path.ipynb  [0m[01;34mdata[0m/  [01;34mFineGym[0m/  [01;34mresults[0m/  [01;34mrun[0m/  transforms.py


In [None]:
# !mkdir FineGym
# !unzip /content/drive/MyDrive/carknows/data/FineGym.zip -d /content/FineGym

Build local file system on Colab

In [None]:
!mkdir /content/Brain4Cars
# # !cd /content/Brain4Cars
# # !mkdir Data Lab
!mkdir /content/Brain4Cars/Data 
!mkdir /content/Brain4Cars/lab
!unzip /content/drive/MyDrive/carknows/data/Brain4Car/clip_lab.zip -d /content/Brain4Cars/lab
!unzip /content/drive/MyDrive/carknows/data/Brain4Car/face.zip -d /content/Brain4Cars/Data

In [None]:
# Load the TensorBoard notebook extension and open the logs of one recorded run.
%load_ext tensorboard
%tensorboard --logdir /content/drive/MyDrive/carknows/run/run_10/models/Jan15_08-45-39_eedb267c90f3

In [None]:
# Show which GPU (if any) is attached to this runtime.
!nvidia-smi

In [None]:
import os
import torch
import cv2
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from PIL import Image
import scipy.io as scio
from torchvision.transforms import transforms


class Path(object):
    """Lookup of the on-disk directories used by each supported dataset."""

    # dataset name -> (root_dir, output_dir, seq_lab_dir, flow_dir)
    #   root_dir:    folder that contains the class-label sub-folders
    #   output_dir:  folder the preprocessed data is read from / saved to
    #   seq_lab_dir: folder holding the per-clip sequence labels
    #   flow_dir:    currently the same folder as root_dir for every dataset
    _DATASET_DIRS = {
        'face': (
            '/content/Brain4Cars/Data',
            '/content/Brain4Cars/Data',
            '/content/Brain4Cars/lab/clip_lab',
            '/content/Brain4Cars/Data',
        ),
        'gtea': (
            '/content/drive/MyDrive/carknows/data/GTEA/GTEA_dt',
            '/content/drive/MyDrive/carknows/data/GTEA/GTEA_dt',
            '/content/drive/MyDrive/carknows/data/GTEA/GTEA_lab',
            '/content/drive/MyDrive/carknows/data/GTEA/GTEA_dt',
        ),
        'finegym': (
            '/content/FineGym/finegym_data',
            '/content/FineGym/finegym_data',
            '/content/FineGym/finegym_label',
            '/content/FineGym/finegym_data',
        ),
    }

    @staticmethod
    def db_dir(database):
        """Return ``(root_dir, output_dir, seq_lab_dir, flow_dir)`` for ``database``.

        Args:
            database: one of ``'face'``, ``'gtea'`` or ``'finegym'``.

        Raises:
            ValueError: if ``database`` is not a known dataset name.
        """
        try:
            return Path._DATASET_DIRS[database]
        except KeyError:
            raise ValueError(
                'Unknown database {!r}; expected one of {}'.format(
                    database, sorted(Path._DATASET_DIRS))) from None


class CarDataset_multi_2(Dataset):
    """Video-clip dataset returning two views of each clip plus its labels.

    The directory layout read by ``__init__`` is
    ``<output_dir>/<split>/<label>/<clip>/<frame images>`` with a parallel
    ``<seq_lab_dir>/<split>/<label>/<clip>/`` folder holding a ``.mat`` file of
    per-frame labels.

    Each item is ``(buffer, buffer_resnet, label, seq_labels)``:

    * ``buffer``: frames resized to 128x128 and randomly cropped to ``crop_size``.
    * ``buffer_resnet``: the same frames at their stored resolution (no resize, no crop).
    * ``label``: index of the clip's class folder (0-d tensor).
    * ``seq_labels``: int32 array loaded from the clip's ``.mat`` file.

    Args:
        dataset: ``'face'``, ``'gtea'`` or ``'finegym'``.
        split: sub-folder name of the split, e.g. ``'train'``, ``'val'``, ``'test'``.
        clip_len: number of frames sampled from every video.
        transform1: optional callable applied to ``buffer`` (receives a float tensor in [0, 1]
            when the frames are 8-bit).
        transform2: optional callable applied to ``buffer_resnet`` (receives the raw,
            unscaled NumPy array).

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """

    # dataset -> (height, width) of the frames as stored on disk.
    _RESNET_FRAME_SIZE = {
        'gtea': (200, 320),
        'face': (224, 224),
        'finegym': (180, 320),
    }

    # dataset -> (file name, variable name) of the per-clip sequence labels.
    _SEQ_LABEL_SOURCE = {
        'face': ('lab.mat', 'lab_gt'),
        'gtea': ('state.mat', 'state_p'),
        'finegym': ('state.mat', 'state'),
    }

    def __init__(self, dataset='face', split='val', clip_len=16, transform1=None, transform2=None):
        if dataset not in self._RESNET_FRAME_SIZE:
            raise ValueError('Unsupported dataset {!r}; expected one of {}'.format(
                dataset, sorted(self._RESNET_FRAME_SIZE)))

        self.root_dir, self.output_dir, self.seq_lab_dir, self.flow_dir = Path.db_dir(dataset)
        folder = os.path.join(self.output_dir, split)
        print('folder:', folder)
        seqlab_folder = os.path.join(self.seq_lab_dir, split)

        self.clip_len = clip_len
        self.split = split

        self.fnames, labels, self.flabels = [], [], []

        self.resize_height = 128
        self.resize_width = 128
        self.crop_size = 112
        self.dataset = dataset
        self.resize_height_resnet, self.resize_width_resnet = self._RESNET_FRAME_SIZE[dataset]

        self.transform1 = transform1
        self.transform2 = transform2

        # One entry per clip folder; the class name is the parent folder.
        for label in sorted(os.listdir(folder)):
            for fname in os.listdir(os.path.join(folder, label)):
                self.fnames.append(os.path.join(folder, label, fname))
                self.flabels.append(os.path.join(seqlab_folder, label, fname))
                labels.append(label)
        assert len(labels) == len(self.fnames)
        print('Number of {} videos: {:d}'.format(split, len(self.fnames)))

        # Prepare a mapping between the label names (strings) and indices (ints)
        self.label2index = {label: index for index, label in enumerate(sorted(set(labels)))}
        # Convert the list of label names into an array of label indices
        self.label_array = np.array([self.label2index[label] for label in labels], dtype=int)

    def __len__(self):
        """Number of clips in this split."""
        return len(self.fnames)

    def __getitem__(self, index):
        """Load, sample, crop and convert the clip at ``index``."""
        buffer, buffer_resnet = self.load_frames(self.fnames[index])
        buffer, buffer_resnet = self.crop(buffer, buffer_resnet, self.clip_len, self.crop_size)
        labels = np.array(self.label_array[index])
        seq_labels = self.get_seq_labels_from_mat(self.flabels[index], self.dataset)

        buffer, buffer_resnet = self.to_tensor(buffer, buffer_resnet, self.transform1, self.transform2)

        return buffer, buffer_resnet, torch.from_numpy(labels), seq_labels

    def load_frames(self, file_dir):
        """Read every image in ``file_dir`` (sorted by name) into two float32 arrays.

        Returns:
            ``(buffer, buffer_resnet)`` of shapes
            ``(n_frames, resize_height, resize_width, 3)`` and
            ``(n_frames, resize_height_resnet, resize_width_resnet, 3)``.
            ``buffer_resnet`` receives the frames unresized, so the images on disk must
            already have that size.
        """
        frames = sorted(os.path.join(file_dir, img) for img in os.listdir(file_dir))
        frame_count = len(frames)

        buffer = np.empty((frame_count, self.resize_height, self.resize_width, 3), np.dtype('float32'))
        buffer_resnet = np.empty((frame_count, self.resize_height_resnet, self.resize_width_resnet, 3),
                                 np.dtype('float32'))
        for i, frame_name in enumerate(frames):
            # The context manager closes the underlying file once both copies are made.
            with Image.open(frame_name) as frame_pil:
                # PIL expects (width, height); the resulting array is (height, width, 3).
                buffer[i] = frame_pil.resize((self.resize_width, self.resize_height))
                buffer_resnet[i] = frame_pil

        return buffer, buffer_resnet

    def to_tensor(self, buffer, buffer_resnet, use_transform1, use_transform2):
        """Convert both frame arrays to tensors, applying the optional transforms.

        Without ``use_transform1`` the clip is returned channel-first ``(C, T, H, W)`` and
        unscaled; with it, the transform receives ``buffer / 255`` as a tensor.
        Without ``use_transform2`` the frames are returned as ``(T, C, H, W)``; with it,
        the transform receives the NumPy array as-is.
        """
        if use_transform1 is not None:
            buffer = use_transform1(torch.from_numpy(buffer) / 255)
        else:
            buffer = torch.from_numpy(buffer.transpose((3, 0, 1, 2)))

        if use_transform2 is not None:
            buffer_resnet = use_transform2(buffer_resnet)
        else:
            # used for C3D
            buffer_resnet = torch.from_numpy(buffer_resnet.transpose((0, 3, 1, 2)))

        return buffer, buffer_resnet

    def crop(self, buffer, buffer_resnet, clip_len, crop_size):
        """Sample ``clip_len`` frames and take one random spatial crop of ``buffer``.

        For ``clip_len > 120`` the frame indices are ``ceil(linspace(0, clip_len, clip_len))``,
        independent of the video length.
        NOTE(review): the largest index is then ``clip_len`` itself, so the video must hold
        at least ``clip_len + 1`` frames - confirm this matches the stored clips.

        Otherwise a random start is drawn from the first fifth of the video and a random
        end from the last fifth (leaving the final ``pred_con`` frames out), and
        ``clip_len`` indices are spread evenly between them.

        ``buffer_resnet`` is sub-sampled with the same frame indices but not cropped.
        """
        pred_con = 5   # number of trailing frames never used as the end index
        bound_sep = 5  # the video is split into this many equal parts

        if clip_len > 120:
            split_index = np.linspace(0, clip_len, clip_len)
        else:
            # Cast to int so randint receives integer bounds.
            bound = int(np.rint(buffer.shape[0] / bound_sep))

            begin_index = np.random.randint(0, bound)
            upper_bound = buffer.shape[0] - pred_con
            end_index = np.random.randint(bound * (bound_sep - 1), upper_bound)
            split_index = np.linspace(begin_index, end_index, clip_len)
        split_index = np.ceil(split_index).astype('int64')

        # Randomly select start indices in order to crop the video
        height_index = np.random.randint(buffer.shape[1] - crop_size)
        width_index = np.random.randint(buffer.shape[2] - crop_size)

        buffer = buffer[split_index,
                        height_index:height_index + crop_size,
                        width_index:width_index + crop_size, :]

        buffer_resnet = buffer_resnet[split_index, :, :, :]

        return buffer, buffer_resnet

    def get_seq_labels_from_mat(self, flabels, dataset):
        """Load the per-frame label sequence stored in the clip's ``.mat`` file.

        Args:
            flabels: folder that contains the label file of one clip.
            dataset: dataset name; selects the file name and the variable read from it.

        Returns:
            The label array cast to int32.

        Raises:
            ValueError: if ``dataset`` is not supported.
        """
        try:
            mat_name, mat_key = self._SEQ_LABEL_SOURCE[dataset]
        except KeyError:
            raise ValueError('Unsupported dataset {!r}; expected one of {}'.format(
                dataset, sorted(self._SEQ_LABEL_SOURCE))) from None

        mat_path = os.path.join(flabels, mat_name)
        seq_lab = scio.loadmat(mat_path)[mat_key].astype("int32")

        # identify if any folders have incomplete data
        if seq_lab.shape[1] < 151:
            print('len_seq_lab', seq_lab.shape, seq_lab)
            print('mat_path', mat_path)
        return seq_lab

In [None]:
import torch
import torch.nn as nn

print('main')
# Dataset key passed to CarDataset_multi_2 and number of frames sampled from each video.
dataset = 'face'
clip_len = 150

class ConvertBHWCtoBCHW(nn.Module):
    """Move the trailing channel axis of a (B, H, W, C) tensor to position 1, giving (B, C, H, W)."""

    def forward(self, vid: torch.Tensor) -> torch.Tensor:
        batch_ax, height_ax, width_ax, channel_ax = range(4)
        return vid.permute(batch_ax, channel_ax, height_ax, width_ax)

class ConvertBCHWtoCBHW(nn.Module):
    """Swap the first two axes of a (B, C, H, W) tensor, giving (C, B, H, W)."""

    def forward(self, vid: torch.Tensor) -> torch.Tensor:
        batch_ax, channel_ax, height_ax, width_ax = range(4)
        return vid.permute(channel_ax, batch_ax, height_ax, width_ax)

# Per-channel normalisation statistics used by transform1 (clip view).
mean = [0.43216, 0.394666, 0.37645]
std = [0.22803, 0.22145, 0.216989]

resize_size = 224
crop_size = 112

# Per-channel normalisation statistics used by transform2 (full-frame view).
mean_res = [0.485, 0.456, 0.406]
std_res = [0.229, 0.224, 0.225]

# Clip view: (T, H, W, C) -> (T, C, H, W) -> normalise -> (C, T, H, W).
trans1 = [
    ConvertBHWCtoBCHW(),
    transforms.ConvertImageDtype(torch.float32),
    transforms.Normalize(mean=mean, std=std),
    ConvertBCHWtoCBHW(),
]

# Full-frame view: (T, H, W, C) -> (T, C, H, W) -> normalise.
trans2 = [
    ConvertBHWCtoBCHW(),
    transforms.ConvertImageDtype(torch.float32),
    transforms.Normalize(mean=mean_res, std=std_res),
]

transform1_t = transforms.Compose(trans1)
transform2_t = transforms.Compose(trans2)

# Settings shared by the three splits. transform2_t is built above but the
# datasets are created with transform2=None.
dataset_kwargs = dict(dataset=dataset, clip_len=clip_len,
                      transform1=transform1_t, transform2=None)
loader_kwargs = dict(batch_size=3, num_workers=4, pin_memory=True)

train_dataloader = DataLoader(CarDataset_multi_2(split='train', **dataset_kwargs),
                              shuffle=True, **loader_kwargs)

val_dataloader = DataLoader(CarDataset_multi_2(split='val', **dataset_kwargs),
                            **loader_kwargs)

test_dataloader = DataLoader(CarDataset_multi_2(split='test', **dataset_kwargs),
                             shuffle=True, **loader_kwargs)


# Sanity check: pull the first two training batches and print their shapes.
for i, sample in enumerate(train_dataloader):
    inputs1, inputs2, labels, labels_seq = sample
    print('main input 1', inputs1.size())
    print('main input 2', inputs2.size())
    print('main label', labels, labels.size())
    print('main label seq', labels_seq.size())

    if i == 1:
        break

LSTM_ANNO: extracts high-level temporal features from each clip (implemented with two stacked bidirectional GRUs, despite the name).

Fusion_net_2: fuses the features from the 3D and 2D branches.

prediction_net: fuses the two-branch features and predicts the intent and activity scores for the next clip.


In [None]:
!pip install tensorboardX

import torch
import torch.nn as nn
import timeit
from datetime import datetime
import socket
import os
import glob
from tqdm import tqdm
import torch
from PIL import Image
from tensorboardX import SummaryWriter
from torch import nn, optim
from torch.utils.data import DataLoader
from collections import Counter
from torchvision.transforms import transforms
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable

class LSTM_ANNO(nn.Module):
    """Two stacked bidirectional GRUs followed by a two-layer classification head.

    ``forward`` takes a batch-first sequence of 512-dimensional features and returns
    the class scores computed from the last time step of the second GRU, together
    with a copy of the first GRU's full output sequence.
    """

    def __init__(self, num_classes=5):
        super().__init__()
        feat_dim = 512
        self.clip_len = 16
        self.lstm_hidden_size = 512
        self.num_classes = num_classes

        # Bidirectional with hidden size feat_dim // 2, so each GRU outputs feat_dim features.
        # Input width per feature source - Fusion net: 1152, Concate: 1024, C3D: 4096, RES3D: 512
        gru_options = dict(num_layers=2, batch_first=True, bidirectional=True)
        self.gru1 = nn.GRU(feat_dim, feat_dim // 2, **gru_options)
        self.gru2 = nn.GRU(feat_dim, feat_dim // 2, **gru_options)

        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(feat_dim, feat_dim // 4)
        self.dropout = nn.Dropout(p=0.5)  # registered but not applied in forward
        self.fc_final_score = nn.Linear(feat_dim // 4, self.num_classes)

    def forward(self, x):
        first_pass, _ = self.gru1(x, None)
        gru1_output = first_pass.clone()

        second_pass, _ = self.gru2(first_pass)
        last_step = second_pass[:, -1, :]

        hidden = self.relu(self.fc1(last_step))
        final_score = self.fc_final_score(hidden)
        return final_score, gru1_output


def basic_multi_fusion(spat_feat, temp_feat):
    """Fuse two feature tensors by element-wise multiplication."""
    fused = spat_feat * temp_feat
    return fused


def basic_sum_fusion(spat_feat, temp_feat):
    """Fuse two feature tensors by element-wise addition."""
    fused = spat_feat + temp_feat
    return fused


def basic_max_fusion(spat_feat, temp_feat):
    """Fuse two feature tensors by taking their element-wise maximum."""
    fused = torch.max(spat_feat, temp_feat)
    return fused


def basic_concate_fusion(spat_feat, temp_feat):
    """Fuse two feature tensors by concatenating them along dimension 1 (spatial first)."""
    pair = (spat_feat, temp_feat)
    return torch.cat(pair, 1)


class basic_conv_fusion(nn.Module):
    """Fuse a temporal and a spatial feature tensor with a 1-D convolution.

    Inputs with more than two dimensions are concatenated along dimension 1
    (temporal first) and passed through ``conv1d``, which expects
    ``2 * pred_horizon`` input channels. Inputs with at most two dimensions are
    stacked into two channels, passed through ``conv1d_2`` and flattened per sample.
    """

    def __init__(self, pred_horizon):
        super().__init__()
        self.conv1d = nn.Conv1d(in_channels=pred_horizon * 2, out_channels=256,
                                kernel_size=3, stride=1)
        self.conv1d_2 = nn.Conv1d(in_channels=2, out_channels=16,
                                  kernel_size=3, stride=3)

    def forward(self, spat_feat, temp_feat):
        if spat_feat.dim() <= 2:
            stacked = torch.stack((temp_feat, spat_feat), 1)
            convolved = self.conv1d_2(stacked)
            return convolved.view(spat_feat.size()[0], -1)

        joined = torch.cat((temp_feat, spat_feat), 1)
        return self.conv1d(joined)

class basic_conv_demen1(nn.Module):
    """Collapse a 16-channel 5-D feature map to a single channel with 3-D convolutions.

    ``forward`` swaps dimensions 1 and 2 of its input, so the size-16 axis must be
    dimension 2 of ``x``. Two paths are summed: a two-step reduction
    (16 -> 4 -> 1 channels) and a direct one (16 -> 1). The single output channel
    is squeezed away, so the result has one dimension fewer than the input.
    """

    def __init__(self):
        super(basic_conv_demen1, self).__init__()
        # self.c1 = nn.Conv3d(in_channels=64, out_channels=16, kernel_size=3, padding=1)
        self.c2 = nn.Conv3d(in_channels=16, out_channels=4, kernel_size=3, padding=1)
        self.c3 = nn.Conv3d(in_channels=4, out_channels=1, kernel_size=3, padding=1)

        self.c4 = nn.Conv3d(in_channels=16, out_channels=1, kernel_size=3, padding=1)

    def forward(self, x):
        x0 = x.permute((0, 2, 1, 3, 4))

        # Two-step channel reduction: 16 -> 4 -> 1.
        x1 = self.c3(self.c2(x0)).squeeze(1)
        # Direct channel reduction: 16 -> 1.
        x2 = self.c4(x0).squeeze(1)

        return x1 + x2


class Fusion_net_2(nn.Module):
    """Fuse temporal and spatial features and classify the fused vector.

    Each input is scaled by its own learnable scalar, the two are multiplied
    element-wise, and the product feeds two linear heads.

    ``forward`` returns ``(x0, logits, x1, [spat_weigh, temp_weigh])`` where
    ``x0`` is the ``fc1_int`` projection and ``x1`` the ReLU-activated ``fc1_act``
    projection, both with a size-1 axis inserted at position 1, and ``logits``
    are the class scores from ``fc2_act``.
    """

    def __init__(self, num_classes=4, fusion_ind=True, pred_horizon=8):
        super().__init__()

        self.fc_hidden1 = 512          # 512, 1024, 2720(conv1d)
        self.fc_hidden2 = 512
        self.reduction_ratio = 2
        self.num_classes = num_classes

        self.relu = nn.ReLU()

        self.fc1_act = nn.Linear(self.fc_hidden1, self.fc_hidden2)
        self.fc2_act = nn.Linear(self.fc_hidden2, self.num_classes)

        self.fc1_int = nn.Linear(self.fc_hidden1, self.fc_hidden2)
        self.dropout = nn.Dropout(p=0.5)

        # Learnable scalar weights of the two branches.
        self.spat_weigh = nn.Parameter(torch.tensor([1.0]))
        self.temp_weigh = nn.Parameter(torch.tensor([1.0]))

        self.fusion_ind = fusion_ind
        self.convfusion = basic_conv_fusion(pred_horizon=pred_horizon)

    def forward(self, x_temp, x_spat):
        # A bounded variant (temp_weigh = 1 - spat_weigh, both clamped to [0, 1],
        # enabled by fusion_ind) and a conv-based fusion via self.convfusion were
        # tried here; the unbounded multiplicative fusion below is the active one.
        weighted_temp = self.temp_weigh * x_temp
        weighted_spat = self.spat_weigh * x_spat
        x_fuse = basic_multi_fusion(weighted_spat, weighted_temp)

        x0 = self.fc1_int(x_fuse).unsqueeze(0).transpose(0, 1)

        act_hidden = self.relu(self.fc1_act(x_fuse))
        logits = self.fc2_act(act_hidden)
        x1 = act_hidden.unsqueeze(0).transpose(0, 1)

        return x0, logits, x1, [self.spat_weigh, self.temp_weigh]

    def get_1x_lr_params(model):
        """
        This generator returns all the parameters for conv and two fc layers of the net.
        """
        for module in (model.fc1_act, model.fc2_act, model.fc1_int, model.convfusion):
            for param in module.parameters():
                if param.requires_grad:
                    yield param


class prediction_net(nn.Module):
    """Predict intent and activity scores from activity features and a GRU feature sequence.

    The activity features pass through two stacked bidirectional GRUs, are scaled
    by a learnable scalar and multiplied element-wise with the (also scaled)
    ``x_gru_out`` sequence. The last time step of the product feeds two separate
    heads.

    ``forward`` returns
    ``(final_intent_score, final_activity_score, [intent_weigh, act_weigh])``.
    """

    def __init__(self, num_activity_classes=4, num_intent_classes=5, fusion_ind=True, pred_horizon=7):
        super().__init__()
        self.num_activity_classes = num_activity_classes
        self.num_intent_classes = num_intent_classes

        feat_dim = 512
        fused_dim = 512
        half_dim = feat_dim // 2

        # Input width per feature source - Fusion net: 1152, Concate: 1024, C3D: 4096, RES3D: 512
        gru_options = dict(num_layers=2, batch_first=True, bidirectional=True)
        self.gru1 = nn.GRU(feat_dim, half_dim, **gru_options)
        self.gru2 = nn.GRU(feat_dim, half_dim, **gru_options)

        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(fused_dim, half_dim)
        self.fc2 = nn.Linear(fused_dim, half_dim)
        self.fc_final_intent_score = nn.Linear(half_dim, self.num_intent_classes)
        self.fc_final_activity_score = nn.Linear(half_dim, self.num_activity_classes)

        # Learnable scalar weights of the two inputs.
        self.intent_weigh = nn.Parameter(torch.tensor([1.0]))
        self.act_weigh = nn.Parameter(torch.tensor([1.0]))
        self.fusion_ind = fusion_ind
        self.convfusion = basic_conv_fusion(pred_horizon=pred_horizon)

    def forward(self, x_activity_feats, x_gru_out):
        activity_seq, _ = self.gru1(x_activity_feats)
        activity_seq, _ = self.gru2(activity_seq)

        # A bounded variant (act_weigh = 1 - intent_weigh, both clamped to [0, 1],
        # enabled by fusion_ind) and a conv-based fusion via self.convfusion were
        # tried here; the unbounded multiplicative fusion below is the active one.
        weighted_activity = self.act_weigh * activity_seq
        weighted_intent = self.intent_weigh * x_gru_out
        fused_last_step = basic_multi_fusion(weighted_activity, weighted_intent)[:, -1, :]

        activity_hidden = self.relu(self.fc1(fused_last_step))
        intent_hidden = self.relu(self.fc2(fused_last_step))

        final_intent_score = self.fc_final_intent_score(intent_hidden)
        final_activity_score = self.fc_final_activity_score(activity_hidden)

        return final_intent_score, final_activity_score, [self.intent_weigh, self.act_weigh]

    def get_1x_lr_params(model):
        """
        This generator returns all the parameters for conv and two fc layers of the net.
        """
        trainable_modules = (
            model.gru1, model.gru2, model.fc1, model.fc2,
            model.fc_final_intent_score, model.fc_final_activity_score,
            model.convfusion,
        )
        for module in trainable_modules:
            for param in module.parameters():
                if param.requires_grad:
                    yield param


# Smoke test: run prediction_net once on random features of shape
# (batch=3, time=8, feature=512) and print the shapes of the two score tensors.
inputs1 = torch.rand(3, 8, 512)
inputs2 = torch.rand(3, 8, 512)
net = prediction_net(num_activity_classes=4, num_intent_classes=5, pred_horizon=8)
# Call the module itself rather than .forward() so registered hooks are run.
output1, output2, [int_weigh, act_weigh] = net(inputs1, inputs2)

print('outputs size:', output1.size(), output2.size())

Initialize either the 3D ResNet-18 or the R(2+1)D backbone

In [None]:
try:
    from torch.hub import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url

import torch.nn as nn
import torch
from typing import Type, Any, Callable, Union, List, Optional
import numpy as np

# __all__ = ['r3d_18', 'ResNet2d', 'resnet18']

# Download locations of the pretrained checkpoints, keyed by architecture name.
model_urls = {
    'r3d_18': 'https://download.pytorch.org/models/r3d_18-b3b3357e.pth',
    'r2plus1d_18': 'https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth',
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
}

def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """Build a bias-free 3x3 2-D convolution whose padding equals its dilation."""
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
    return layer


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """Build a bias-free 1x1 2-D convolution."""
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return layer


class ChannelGate_layer1(nn.Module):
    """Channel attention gate over a size-16 axis of a 5-D feature map.

    ``forward`` swaps dimensions 1 and 2 of its input, so the gated axis (size 16)
    must be dimension 2 of ``x``. For every entry of ``pool_types`` the remaining
    three axes are pooled to a single value per channel and passed through a shared
    MLP; the MLP outputs are summed, squashed with a sigmoid and used to rescale
    the permuted input.

    NOTE: the gate previously used only the MLP output of the last pooling type
    (the running sum was computed but discarded). Weights trained with that
    behaviour will see different activations.

    Args:
        pool_types: iterable containing ``'avg'`` and/or ``'max'``.

    Returns (forward):
        The gated tensor in the permuted layout, i.e. with dimensions 1 and 2 of the
        input swapped.

    Raises (forward):
        ValueError: if ``pool_types`` is empty or holds an unknown entry.
    """

    def __init__(self, pool_types=('avg', 'max')):
        super(ChannelGate_layer1, self).__init__()
        self.gate_channels = 16
        self.mlp = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.gate_channels, self.gate_channels // 2),
            nn.ReLU(),
            nn.Linear(self.gate_channels // 2, self.gate_channels)
            )
        self.pool_types = pool_types

    def forward(self, x):
        x = x.permute((0, 2, 1, 3, 4))
        # Pool over the full extent of the last three axes -> one value per channel.
        window = (x.size(2), x.size(3), x.size(4))

        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type == 'avg':
                pooled = nn.AvgPool3d(window, stride=window)(x)
            elif pool_type == 'max':
                pooled = nn.MaxPool3d(window, stride=window)(x)
            else:
                raise ValueError('Unsupported pool type: {!r}'.format(pool_type))

            channel_att_raw = self.mlp(pooled)
            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw

        if channel_att_sum is None:
            raise ValueError('pool_types must contain at least one pooling type')

        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).unsqueeze(4).expand_as(x)
        return x * scale


class ChannelGate_layer2(nn.Module):
    """Channel attention gate over a size-8 axis of a 5-D feature map.

    ``forward`` swaps dimensions 1 and 2 of its input, so the gated axis (size 8)
    must be dimension 2 of ``x``. For every entry of ``pool_types`` the remaining
    three axes are pooled to a single value per channel and passed through a shared
    MLP; the MLP outputs are summed, squashed with a sigmoid and used to rescale
    the permuted input.

    NOTE: the gate previously used only the MLP output of the last pooling type
    (the running sum was computed but discarded). Weights trained with that
    behaviour will see different activations.

    Args:
        pool_types: iterable containing ``'avg'`` and/or ``'max'``.

    Returns (forward):
        The gated tensor in the permuted layout, i.e. with dimensions 1 and 2 of the
        input swapped.

    Raises (forward):
        ValueError: if ``pool_types`` is empty or holds an unknown entry.
    """

    def __init__(self, pool_types=('avg', 'max')):
        super(ChannelGate_layer2, self).__init__()
        self.gate_channels = 8
        self.mlp = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.gate_channels, self.gate_channels // 2),
            nn.ReLU(),
            nn.Linear(self.gate_channels // 2, self.gate_channels)
        )
        self.pool_types = pool_types

    def forward(self, x):
        x = x.permute((0, 2, 1, 3, 4))
        # Pool over the full extent of the last three axes -> one value per channel.
        window = (x.size(2), x.size(3), x.size(4))

        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type == 'avg':
                pooled = nn.AvgPool3d(window, stride=window)(x)
            elif pool_type == 'max':
                pooled = nn.MaxPool3d(window, stride=window)(x)
            else:
                raise ValueError('Unsupported pool type: {!r}'.format(pool_type))

            channel_att_raw = self.mlp(pooled)
            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw

        if channel_att_sum is None:
            raise ValueError('pool_types must contain at least one pooling type')

        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).unsqueeze(4).expand_as(x)

        return x * scale


class ChannelGate_layer3(nn.Module):
    """Channel attention gate over a size-4 axis of a 5-D feature map.

    ``forward`` swaps dimensions 1 and 2 of its input, so the gated axis (size 4)
    must be dimension 2 of ``x``. For every entry of ``pool_types`` the remaining
    three axes are pooled to a single value per channel and passed through a shared
    MLP; the MLP outputs are summed, squashed with a sigmoid and used to rescale
    the permuted input.

    NOTE: the gate previously used only the MLP output of the last pooling type
    (the running sum was computed but discarded). Weights trained with that
    behaviour will see different activations.

    Args:
        pool_types: iterable containing ``'avg'`` and/or ``'max'``.

    Returns (forward):
        The gated tensor in the permuted layout, i.e. with dimensions 1 and 2 of the
        input swapped.

    Raises (forward):
        ValueError: if ``pool_types`` is empty or holds an unknown entry.
    """

    def __init__(self, pool_types=('avg', 'max')):
        super(ChannelGate_layer3, self).__init__()
        self.gate_channels = 4
        self.mlp = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.gate_channels, self.gate_channels // 2),
            nn.ReLU(),
            nn.Linear(self.gate_channels // 2, self.gate_channels)
        )
        self.pool_types = pool_types

    def forward(self, x):
        x = x.permute((0, 2, 1, 3, 4))
        # Pool over the full extent of the last three axes -> one value per channel.
        window = (x.size(2), x.size(3), x.size(4))

        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type == 'avg':
                pooled = nn.AvgPool3d(window, stride=window)(x)
            elif pool_type == 'max':
                pooled = nn.MaxPool3d(window, stride=window)(x)
            else:
                raise ValueError('Unsupported pool type: {!r}'.format(pool_type))

            channel_att_raw = self.mlp(pooled)
            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw

        if channel_att_sum is None:
            raise ValueError('pool_types must contain at least one pooling type')

        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).unsqueeze(4).expand_as(x)

        return x * scale


class ChannelGate_layer4(nn.Module):
    """Channel attention gate over a size-2 axis of a 5-D feature map.

    ``forward`` swaps dimensions 1 and 2 of its input, so the gated axis (size 2)
    must be dimension 2 of ``x``. For every entry of ``pool_types`` the remaining
    three axes are pooled to a single value per channel and passed through a shared
    single linear layer; the outputs are summed, squashed with a sigmoid and used
    to rescale the permuted input.

    NOTE: the gate previously used only the output for the last pooling type
    (the running sum was computed but discarded). Weights trained with that
    behaviour will see different activations.

    Args:
        pool_types: iterable containing ``'avg'`` and/or ``'max'``.

    Returns (forward):
        The gated tensor in the permuted layout, i.e. with dimensions 1 and 2 of the
        input swapped.

    Raises (forward):
        ValueError: if ``pool_types`` is empty or holds an unknown entry.
    """

    def __init__(self, pool_types=('avg', 'max')):
        super(ChannelGate_layer4, self).__init__()
        self.gate_channels = 2
        self.mlp = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.gate_channels, self.gate_channels),
        )
        self.pool_types = pool_types

    def forward(self, x):
        x = x.permute((0, 2, 1, 3, 4))
        # Pool over the full extent of the last three axes -> one value per channel.
        window = (x.size(2), x.size(3), x.size(4))

        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type == 'avg':
                pooled = nn.AvgPool3d(window, stride=window)(x)
            elif pool_type == 'max':
                pooled = nn.MaxPool3d(window, stride=window)(x)
            else:
                raise ValueError('Unsupported pool type: {!r}'.format(pool_type))

            channel_att_raw = self.mlp(pooled)
            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw

        if channel_att_sum is None:
            raise ValueError('pool_types must contain at least one pooling type')

        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).unsqueeze(4).expand_as(x)

        return x * scale


class ChannelPool(nn.Module):
    """Reduce dimension 1 to two channels: its maximum followed by its mean."""

    def forward(self, x):
        peak = x.max(dim=1, keepdim=True)[0]
        average = x.mean(dim=1, keepdim=True)
        return torch.cat((peak, average), dim=1)


class spatial_attention(nn.Module):
    """Compute a sigmoid attention map from the max/mean channel statistics of a feature map.

    The input is compressed to two channels by ``ChannelPool``, reduced to one
    channel by a 3x3x3 convolution and squashed with a sigmoid. All size-1
    dimensions of the result are then squeezed away.
    """

    def __init__(self):
        super().__init__()
        self.compress = ChannelPool()
        self.spatial_conv = nn.Conv3d(in_channels=2, out_channels=1, kernel_size=3, padding=1)
        # Registered but not applied in forward.
        self.bn = nn.BatchNorm3d(1, eps=1e-5, momentum=0.01, affine=True)

    def forward(self, x):
        pooled = self.compress(x)
        attention = torch.sigmoid(self.spatial_conv(pooled))
        # NOTE(review): squeeze() without a dim also drops the batch axis when the
        # batch size is 1 - confirm callers expect that.
        return attention.squeeze()


class BasicBlock_2d(nn.Module):
    """2-D residual block made of two 3x3 convolutions with a skip connection.

    Args:
        inplanes: number of input channels.
        planes: number of output channels.
        stride: stride of the first convolution.
        downsample: optional module applied to the input before it is added back;
            when ``None`` the input is added unchanged.
        groups, base_width, dilation: only the defaults (1, 64, 1) are supported.
        norm_layer: normalisation layer factory; defaults to ``nn.BatchNorm2d``.

    Raises:
        ValueError: if ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: if ``dilation > 1``.
    """

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        if not (groups == 1 and base_width == 64):
            raise ValueError('BasicBlock_2d only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock_2d")

        # conv1 carries the stride; the downsample branch must match it.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        identity = x if self.downsample is None else self.downsample(x)

        out += identity
        return self.relu(out)

class Bottleneck_2d(nn.Module):
    """2-D bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions with a skip connection.

    The output has ``planes * expansion`` channels. ``pretrained_2d`` is accepted
    but not used by this block.
    """
    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
    # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        pretrained_2d = True
    ):
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        width = int(planes * (base_width / 64.)) * groups
        out_planes = planes * self.expansion

        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, out_planes)
        self.bn3 = norm_layer(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        identity = x if self.downsample is None else self.downsample(x)

        out += identity
        return self.relu(out)


class Conv3DSimple(nn.Conv3d):
    """Bias-free 3x3x3 convolution with the (in, out, mid, stride, padding) builder signature.

    ``midplanes`` is accepted for signature compatibility with the other
    convolution builders and is not used.
    """

    def __init__(self,
                 in_planes,
                 out_planes,
                 midplanes=None,
                 stride=1,
                 padding=1):
        super().__init__(
            in_planes,
            out_planes,
            kernel_size=(3, 3, 3),
            stride=stride,
            padding=padding,
            bias=False,
        )

    @staticmethod
    def get_downsample_stride(stride):
        """Return the (t, h, w) stride a matching downsample layer should use."""
        return (stride,) * 3
        
class Conv2Plus1D(nn.Sequential):
    """Factorised 3-D convolution: a 1x3x3 spatial conv, BatchNorm, ReLU, then a 3x1x1 temporal conv.

    ``midplanes`` is the channel count between the spatial and the temporal
    convolution. ``stride`` and ``padding`` are applied spatially by the first
    convolution and temporally by the second.
    """

    def __init__(self,
                 in_planes,
                 out_planes,
                 midplanes,
                 stride=1,
                 padding=1):
        spatial_conv = nn.Conv3d(in_planes, midplanes,
                                 kernel_size=(1, 3, 3),
                                 stride=(1, stride, stride),
                                 padding=(0, padding, padding),
                                 bias=False)
        mid_norm = nn.BatchNorm3d(midplanes)
        mid_act = nn.ReLU(inplace=True)
        temporal_conv = nn.Conv3d(midplanes, out_planes,
                                  kernel_size=(3, 1, 1),
                                  stride=(stride, 1, 1),
                                  padding=(padding, 0, 0),
                                  bias=False)
        super().__init__(spatial_conv, mid_norm, mid_act, temporal_conv)

    @staticmethod
    def get_downsample_stride(stride):
        """Return the (t, h, w) stride a matching downsample layer should use."""
        return (stride,) * 3


class BasicBlock_3d(nn.Module):
    """Two-convolution 3D residual block.

    Args:
        inplanes: channels of the incoming tensor.
        planes: channels produced by both convolutions.
        conv_builder: callable ``(in, out, mid[, stride])`` returning the conv module.
        stride: stride handed to the first convolution only.
        downsample: optional module applied to the input to form the shortcut.
    """

    expansion = 1

    def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
        super().__init__()
        # With this width, a conv split into a 1x3x3 and a 3x1x1 part holds about
        # as many weights as one full 3x3x3 conv (27*in*out == mid*(9*in + 3*out)).
        full_conv_weights = inplanes * planes * 3 * 3 * 3
        per_mid_channel = inplanes * 3 * 3 + 3 * planes
        midplanes = full_conv_weights // per_mid_channel

        first_conv = conv_builder(inplanes, planes, midplanes, stride)
        self.conv1 = nn.Sequential(first_conv, nn.BatchNorm3d(planes), nn.ReLU(inplace=True))

        second_conv = conv_builder(planes, planes, midplanes)
        self.conv2 = nn.Sequential(second_conv, nn.BatchNorm3d(planes))

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(conv2(conv1(x)) + shortcut(x))``."""
        out = self.conv2(self.conv1(x))
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)
        
        
class Bottleneck_3d(nn.Module):
    """Three-stage 3D residual bottleneck: 1x1x1 reduce, builder conv, 1x1x1 expand.

    The block outputs ``planes * expansion`` channels. Each stage's BatchNorm is
    kept inside its ``convN`` Sequential (there are no separate ``bnN`` attributes).

    Args:
        inplanes: channels of the incoming tensor.
        planes: channels used by the first two stages.
        conv_builder: callable ``(in, out, mid, stride)`` returning the middle conv.
        stride: stride handed to the middle convolution.
        downsample: optional module applied to the input to form the shortcut.
    """
    expansion = 4

    def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
        super().__init__()
        out_planes = planes * self.expansion
        # Same width rule as the basic block; note it is derived from `inplanes`,
        # not from the reduced channel count entering the middle conv.
        midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)

        # Stage 1: 1x1x1 channel reduction.
        reduce_conv = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
        self.conv1 = nn.Sequential(reduce_conv, nn.BatchNorm3d(planes), nn.ReLU(inplace=True))

        # Stage 2: the (possibly strided) convolution supplied by `conv_builder`.
        middle_conv = conv_builder(planes, planes, midplanes, stride)
        self.conv2 = nn.Sequential(middle_conv, nn.BatchNorm3d(planes), nn.ReLU(inplace=True))

        # Stage 3: 1x1x1 channel expansion, no activation before the residual add.
        expand_conv = nn.Conv3d(planes, out_planes, kernel_size=1, bias=False)
        self.conv3 = nn.Sequential(expand_conv, nn.BatchNorm3d(out_planes))

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(conv3(conv2(conv1(x))) + shortcut(x))``."""
        out = self.conv3(self.conv2(self.conv1(x)))
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)
        
        
class BasicStem(nn.Sequential):
    """Default stem: one 3x7x7 convolution (3 -> 64 channels), BatchNorm3d and ReLU.

    The convolution has stride 2 on both spatial axes and stride 1 on the first
    (kernel-3) axis.
    """
    def __init__(self):
        stem_layers = [
            nn.Conv3d(3, 64,
                      kernel_size=(3, 7, 7),
                      stride=(1, 2, 2),
                      padding=(1, 3, 3),
                      bias=False),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True),
        ]
        super().__init__(*stem_layers)


class R2Plus1dStem(nn.Sequential):
    """R(2+1)D stem: the stem convolution is split into a 1x7x7 and a 3x1x1 part.

    Layout (Sequential indices 0..5): 1x7x7 conv 3 -> 45 with spatial stride 2,
    BatchNorm3d, ReLU, 3x1x1 conv 45 -> 64 with stride 1, BatchNorm3d, ReLU.
    """
    def __init__(self):
        spatial_part = [
            nn.Conv3d(3, 45,
                      kernel_size=(1, 7, 7),
                      stride=(1, 2, 2),
                      padding=(0, 3, 3),
                      bias=False),
            nn.BatchNorm3d(45),
            nn.ReLU(inplace=True),
        ]
        temporal_part = [
            nn.Conv3d(45, 64,
                      kernel_size=(3, 1, 1),
                      stride=(1, 1, 1),
                      padding=(1, 0, 0),
                      bias=False),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True),
        ]
        super().__init__(*spatial_part, *temporal_part)


class res_3_2d_net(nn.Module):
    """
    The res_3_2d_net combines resnet3d and resnet2d for model fusion.

    A 3D (video) ResNet and a 2D (image) ResNet run side by side. After each of
    the four residual stages the 3D features are passed through a channel gate
    and a spatial-attention module, and the result multiplies the 2D features.
    Both streams are then globally average-pooled and returned as flat feature
    vectors; no classification head is built.

    The defaults build an R(2+1)D-18 video branch. For a plain R3D-18 branch pass
    ``arch_3d='r3d_18'``, ``conv_makers=[Conv3DSimple] * 4`` and ``stem=BasicStem``.

    Args:
        arch_3d / arch_2d: keys into ``model_urls`` naming the checkpoints to load.
        pretrained_3d / pretrained_2d: load the respective checkpoint when True.
        progress: forwarded to ``load_state_dict_from_url``.
        block_3d / block_2d: residual block classes for the two branches.
        conv_makers: four conv builders, one per 3D stage (only indexed, never mutated).
        layers: number of blocks per stage, shared by both branches (only indexed).
        stem: class building the 3D stem.
        num_classes: unused; kept for signature compatibility.
        zero_init_residual_3d / zero_init_residual_2d: zero the last BatchNorm weight
            of the residual branch (3D: ``Bottleneck_3d`` only; 2D: both block types).
        groups, width_per_group, replace_stride_with_dilation, norm_layer:
            forwarded to the 2D branch construction.
    """

    # r2plus1d_18
    def __init__(self, arch_3d='r2plus1d_18', arch_2d='resnet18', pretrained_3d=True, pretrained_2d=True, progress=True,
                 block_3d=BasicBlock_3d, block_2d=BasicBlock_2d, conv_makers=[Conv2Plus1D] * 4,
                 layers=[2, 2, 2, 2], stem=R2Plus1dStem, num_classes=400,
                 zero_init_residual_3d=False, zero_init_residual_2d=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None, norm_layer=None):

        super(res_3_2d_net, self).__init__()

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer_2d = norm_layer

        self.inplanes_2d = 64
        self.dilation_2d = 1
        self.inplanes_3d = 64

        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))

        # ---- 2D (image) branch; attribute names mirror the checkpoint keys + '_2d' ----
        self.groups_2d = groups
        self.base_width_2d = width_per_group
        self.conv1_2d = nn.Conv2d(3, self.inplanes_2d, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1_2d = norm_layer(self.inplanes_2d)
        self.relu_2d = nn.ReLU(inplace=True)
        self.maxpool_2d = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1_2d = self._make_layer_2d(block_2d, 64, layers[0])
        self.layer2_2d = self._make_layer_2d(block_2d, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3_2d = self._make_layer_2d(block_2d, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4_2d = self._make_layer_2d(block_2d, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool_2d = nn.AdaptiveAvgPool2d((1, 1))

        # ---- 3D (video) branch; attribute names mirror the checkpoint keys + '_3d' ----
        self.stem_3d = stem()
        self.layer1_3d = self._make_layer_3d(block_3d, conv_makers[0], 64, layers[0], stride=1)
        self.layer2_3d = self._make_layer_3d(block_3d, conv_makers[1], 128, layers[1], stride=2)
        self.layer3_3d = self._make_layer_3d(block_3d, conv_makers[2], 256, layers[2], stride=2)
        self.layer4_3d = self._make_layer_3d(block_3d, conv_makers[3], 512, layers[3], stride=2)

        self.avgpool_3d = nn.AdaptiveAvgPool3d((1, 1, 1))

        # ---- fusion modules, one pair per residual stage ----
        self.spatial_pathway1 = spatial_attention()
        self.spatial_pathway2 = spatial_attention()
        self.spatial_pathway3 = spatial_attention()
        self.spatial_pathway4 = spatial_attention()

        self.pool_types = ['avg', 'max']
        self.ChannelGate1 = ChannelGate_layer1(self.pool_types)
        self.ChannelGate2 = ChannelGate_layer2(self.pool_types)
        self.ChannelGate3 = ChannelGate_layer3(self.pool_types)
        self.ChannelGate4 = ChannelGate_layer4(self.pool_types)

        # init weights
        self._initialize_weights_3d()

        if zero_init_residual_3d:
            for m in self.modules():
                if isinstance(m, Bottleneck_3d):
                    # Bottleneck_3d has no `bn3` attribute: its last BatchNorm is the
                    # final element of the `conv3` Sequential.
                    nn.init.constant_(m.conv3[-1].weight, 0)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual_2d:
            for m in self.modules():
                if isinstance(m, Bottleneck_2d):
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock_2d):
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

        # Pretrained weights are loaded last so they overwrite the random initialisation.
        if pretrained_3d:
            self._load_pretrained_branch(arch_3d, '_3d', progress)

        if pretrained_2d:
            self._load_pretrained_branch(arch_2d, '_2d', progress)

    def _load_pretrained_branch(self, arch, suffix, progress):
        """Copy a downloaded checkpoint into the branch whose module names end in ``suffix``.

        The first dotted component of every checkpoint key gets ``suffix`` appended
        (e.g. ``layer1.0...`` -> ``layer1_3d.0...``). Renamed keys that do not exist
        in this model (such as the checkpoint's classifier) are skipped.
        """
        checkpoint = load_state_dict_from_url(model_urls[arch], progress=progress)

        renamed = {}
        for key, tensor in checkpoint.items():
            head, dot, tail = key.partition('.')
            renamed[head + suffix + dot + tail] = tensor

        own_state = self.state_dict()
        own_state.update({k: v for k, v in renamed.items() if k in own_state})
        self.load_state_dict(own_state)

    def forward(self, x_vid, x_im):
        """Encode a video clip and an image and return ``(video_features, image_features)``.

        Both outputs are flattened from dim 1 onward after global average pooling.
        The smoke test below this class prints ``[2, 512]`` for both with the defaults.
        """
        x_vid = self.stem_3d(x_vid)

        x_im = self.conv1_2d(x_im)
        x_im = self.bn1_2d(x_im)
        x_im = self.relu_2d(x_im)
        x_im = self.maxpool_2d(x_im)

        fusion_stages = (
            (self.layer1_3d, self.layer1_2d, self.ChannelGate1, self.spatial_pathway1),
            (self.layer2_3d, self.layer2_2d, self.ChannelGate2, self.spatial_pathway2),
            (self.layer3_3d, self.layer3_2d, self.ChannelGate3, self.spatial_pathway3),
            (self.layer4_3d, self.layer4_2d, self.ChannelGate4, self.spatial_pathway4),
        )
        for layer_3d, layer_2d, channel_gate, spatial_pathway in fusion_stages:
            x_vid = layer_3d(x_vid)
            x_im = layer_2d(x_im)
            # The attention computed from the video features scales the image features.
            x_att = spatial_pathway(channel_gate(x_vid))
            x_im = x_im * x_att

        x_vid = self.avgpool_3d(x_vid)
        x_im = self.avgpool_2d(x_im)

        # flatten(1) already yields (batch, features); no extra reshape is needed.
        x_vid = x_vid.flatten(1)
        x_im = x_im.flatten(1)

        return x_vid, x_im

    def _make_layer_2d(self, block: Type[Union[BasicBlock_2d, Bottleneck_2d]], planes: int, blocks: int,
                    stride: int = 1, dilate: bool = False):
        """Build one 2D residual stage of ``blocks`` blocks and advance ``self.inplanes_2d``.

        With ``dilate=True`` the stride is folded into ``self.dilation_2d`` instead
        of being applied, so the stage keeps its spatial resolution.
        """
        norm_layer = self._norm_layer_2d
        downsample = None
        previous_dilation = self.dilation_2d
        if dilate:
            self.dilation_2d *= stride
            stride = 1
        if stride != 1 or self.inplanes_2d != planes * block.expansion:
            # The shortcut needs a projection whenever resolution or width changes.
            downsample = nn.Sequential(
                conv1x1(self.inplanes_2d, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes_2d, planes, stride, downsample, self.groups_2d,
                            self.base_width_2d, previous_dilation, norm_layer))
        self.inplanes_2d = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes_2d, planes, groups=self.groups_2d,
                                base_width=self.base_width_2d, dilation=self.dilation_2d,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _make_layer_3d(self, block, conv_builder, planes, blocks, stride=1):
        """Build one 3D residual stage of ``blocks`` blocks and advance ``self.inplanes_3d``."""
        downsample = None

        if stride != 1 or self.inplanes_3d != planes * block.expansion:
            # The shortcut stride comes from the builder so it matches the main path.
            ds_stride = conv_builder.get_downsample_stride(stride)
            downsample = nn.Sequential(
                nn.Conv3d(self.inplanes_3d, planes * block.expansion,
                          kernel_size=1, stride=ds_stride, bias=False),
                nn.BatchNorm3d(planes * block.expansion)
            )
        layers = []
        layers.append(block(self.inplanes_3d, planes, conv_builder, stride, downsample))

        self.inplanes_3d = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes_3d, planes, conv_builder))

        return nn.Sequential(*layers)

    def _initialize_weights_3d(self):
        """Initialise every Conv3d, BatchNorm3d and Linear module found in the model."""
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out',
                                        nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


def get_1x_lr_params(model):
    """
    Yield the trainable parameters of the 2D and 3D backbone modules of ``model``.

    The modules visited are the 2D conv/bn stem, the four 2D stages, the 3D stem
    and the four 3D stages (plus both average pools); parameters with
    ``requires_grad=False`` are skipped.
    """
    backbone_modules = (
        model.conv1_2d, model.bn1_2d,
        model.layer1_2d, model.layer2_2d, model.layer3_2d, model.layer4_2d,
        model.avgpool_2d,
        model.stem_3d,
        model.layer1_3d, model.layer2_3d, model.layer3_3d, model.layer4_3d,
        model.avgpool_3d,
    )
    for module in backbone_modules:
        yield from (param for param in module.parameters() if param.requires_grad)

def get_10x_lr_params(model):
    """
    Yield the trainable parameters of the ``layerN_cM`` modules of ``model``.

    NOTE(review): these attribute names are not defined by ``res_3_2d_net``;
    presumably this helper targets a different model - confirm before use.
    """
    head_modules = (
        model.layer1_c1, model.layer1_c2, model.layer1_c3,
        model.layer2_c1, model.layer2_c2, model.layer2_c3,
        model.layer3_c1, model.layer3_c2, model.layer3_c3,
        model.layer4_c1,
    )
    # head_modules = (model.lstm1,)
    for module in head_modules:
        yield from (param for param in module.parameters() if param.requires_grad)


if __name__ == "__main__":
    # Smoke test: build the fused encoder, push one random batch through it and
    # report the output sizes and the trainable-parameter count.
    pretrained_3d = True
    pretrained_2d = True
    progress = True

    inputs1 = torch.rand(2, 3, 16, 112, 112)   # input to the 3D (video) branch
    inputs2 = torch.rand(2, 3, 224, 224)       # input to the 2D (image) branch

    # Pass the flags explicitly so editing them above actually takes effect.
    net = res_3_2d_net(pretrained_3d=pretrained_3d, pretrained_2d=pretrained_2d, progress=progress)

    # Call the module rather than .forward() so registered hooks would run too.
    output1, output2 = net(inputs1, inputs2)
    print('output 3d size:', output1.size(), 'output 2d size:', output2.size())

    parameters = sum(p.numel() for p in net.parameters() if p.requires_grad) / 1_000_000
    print('Trainable Parameters: %.3fM' % parameters)

Downloading: "https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth" to /root/.cache/torch/hub/checkpoints/r2plus1d_18-91a641e6.pth


  0%|          | 0.00/120M [00:00<?, ?B/s]

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

output 3d size: torch.Size([2, 512]) output 2d size: torch.Size([2, 512])
Trainable Parameters: 42.477M


Multi-task learning wrapper

In [None]:
"""
Generate loss wrapper for multi-task learning
"""
import torch
from torch import nn, optim
from torch.autograd import Variable
import numpy as np

class MultiTaskLossWrapper(nn.Module):
    """Combine four task losses using one learnable log-variance per task.

    Each loss ``L_i`` contributes ``exp(-s_i) * L_i + s_i`` where ``s_i`` is
    ``log_vars[i]``. ``forward`` always reads indices 0..3, so ``task_num`` must
    be at least 4.
    """

    def __init__(self, task_num):
        super().__init__()
        self.task_num = task_num
        self.log_vars = nn.Parameter(torch.zeros(task_num))
        print('loss log_vars', self.log_vars)
        self.criterion1 = nn.CrossEntropyLoss()

    def _weighted(self, task_index, task_loss):
        """Return ``exp(-log_var) * task_loss + log_var`` for the given task."""
        log_var = self.log_vars[task_index]
        return torch.exp(-log_var) * task_loss + log_var

    def forward(self, loss_mid_tot, loss_long, intent_lab, predict_intent_score, predict_activity_score, lab_seq_final):
        """Return the total weighted loss.

        ``loss_mid_tot`` and ``loss_long`` are used as given; the activity and
        intent losses are cross-entropies computed here from scores and labels.
        """
        activity_loss = self.criterion1(predict_activity_score, lab_seq_final)
        intent_loss = self.criterion1(predict_intent_score, intent_lab)

        task_losses = (loss_mid_tot, loss_long, activity_loss, intent_loss)
        weighted = [self._weighted(idx, loss) for idx, loss in enumerate(task_losses)]

        print('loss_mid', loss_mid_tot)
        print('loss_long', loss_long)
        print('log_vars', self.log_vars)

        return weighted[0] + weighted[1] + weighted[2] + weighted[3]

In [None]:
import timeit
from datetime import datetime
import socket
import os
import glob
from tqdm import tqdm

import torch
from PIL import Image
from tensorboardX import SummaryWriter
from torch import nn, optim
from torch.utils.data import DataLoader
from torch.autograd import Variable

import numpy as np
import matplotlib.pyplot as plt
from collections import Counter
from torchvision.transforms import transforms

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()
print("Device being used:", device)


nEpochs = 80            # Number of epochs for training
resume_epoch = 0        # Default is 0, change if want to resume
useTest = True          # See evolution of the test set when training
nTestInterval = 1       # Run on test set every nTestInterval epochs
# NOTE(review): snapshot (100) is larger than nEpochs (80); if checkpoints are only
# written every `snapshot` epochs, none would be saved - confirm against train_model.
snapshot = 100       # Store a model every snapshot epochs
lr = 1e-3               # Learning rate
# pre_h = 1

# Captured once per run; used as a filename prefix for the confusion-matrix dumps.
stop_time_init = str(timeit.default_timer())

dataset = 'face'  # Options: 'face', 'gtea' or 'finegym'

if dataset == 'face':
    num_classes = 5             # long-term activity
    num_activtites = 4          # mid-term activity
elif dataset == 'gtea':
    num_classes = 7
    num_activtites = 79
elif dataset == 'finegym':
    num_classes = 4
    num_activtites = 291
else:
    # Fail here instead of leaving num_classes / num_activtites undefined (or stale).
    raise ValueError("Unknown dataset {!r}; expected 'face', 'gtea' or 'finegym'".format(dataset))

num_tasks = 4
require_mtl = True

save_dir_root = '/content/drive/MyDrive/carknows'
print('save_dir_root is:', save_dir_root)


def latest_run_index(run_dirs):
    """Return the largest integer N among paths named ``run_<N>``, or None if there is none.

    The comparison is numeric: a lexical sort would rank ``run_9`` after ``run_10``
    and make the next run reuse an existing directory.
    """
    indices = []
    for run_dir in run_dirs:
        suffix = os.path.basename(os.path.normpath(run_dir)).rsplit('_', 1)[-1]
        if suffix.isdigit():
            indices.append(int(suffix))
    return max(indices) if indices else None


runs = sorted(glob.glob(os.path.join(save_dir_root, 'run', 'run_*')))
latest_run = latest_run_index(runs)
if latest_run is None:
    run_id = 0
elif resume_epoch != 0:
    # Resuming: continue inside the most recent run directory.
    run_id = latest_run
else:
    # Fresh training: allocate the next free run directory.
    run_id = latest_run + 1

save_dir = os.path.join(save_dir_root, 'run', 'run_' + str(run_id))
modelName = '2classes'   
saveName = modelName + '-' + dataset

rand_flg = True

print('savename {}'.format(saveName))

def accuracy(output, target, topk=(1,3)):
    """Computes the precision@k for the specified values of k

    Args:
        output: 2-D score tensor indexed as (sample, class).
        target: 1-D tensor with the true class index of every sample.
        topk: the k values to evaluate.

    Returns:
        A list with one 1-element tensor per k, holding the fraction (0..1) of
        samples whose target is among the k highest-scoring classes. When k
        exceeds the number of classes every class counts as a candidate.
    """
    # torch.topk raises if asked for more entries than there are classes, so clamp.
    maxk = min(max(topk), output.size(1))
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()                              # (maxk, batch): row r holds the rank-r guess
    correct = pred == target.unsqueeze(dim=0)    # broadcasts the targets over the rank axis

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(1.0 / batch_size))
    return res


def process_sig_img(clip_img_final, org_clip_img_size):
    """Resize, centre-crop and normalise a batch of images to 3x224x224.

    Each of the first ``org_clip_img_size[0]`` images is moved to channel-last
    layout, cast to uint8, turned into an RGB PIL image and passed through
    Resize(256) -> CenterCrop(224) -> ToTensor -> Normalize.

    NOTE(review): the uint8 cast assumes pixel values already lie in 0..255 - confirm.

    Returns:
        A new float tensor of shape (org_clip_img_size[0], 3, 224, 224).
    """
    to_resnet_input = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    num_images = org_clip_img_size[0]
    processed = torch.empty([num_images, 3, 224, 224])

    for idx in range(num_images):
        # channel-first -> channel-last, the layout PIL expects
        pixels_hwc = clip_img_final[idx].permute(1, 2, 0).numpy()
        pil_img = Image.fromarray(np.uint8(pixels_hwc)).convert('RGB')
        processed[idx] = to_resnet_input(pil_img)

    return processed


def generate_last_prediction_data(clip_i, inputs, inputs_resnet, seq_labels):
    """Cut one 16-frame window starting at ``clip_i`` and build its model inputs and label.

    Args:
        clip_i: first frame index of the window.
        inputs: video tensor sliced along dim 2 (frames) for the 3D branch.
        inputs_resnet: frame tensor indexed along dim 1; one frame of the window
            is picked and preprocessed with ``process_sig_img`` for the 2D branch.
        seq_labels: per-frame labels, sliced along dim 2.

    Returns:
        ``(clip_seg_final, clip_img_final, seqeunce_label, seqeunce_percent)``.
        The first three are on ``device``. ``seqeunce_label`` is, for every
        sample, the most frequent label in the window minus 1, as a long tensor.
        ``seqeunce_percent`` is always 0 (unused by the current study).
    """
    # NOTE(review): randint's upper bound is exclusive, so offsets 0..14 are drawn
    # and the last frame of the window (offset 15) is never picked - confirm intent.
    rand_ind = np.random.randint(0, 15) if rand_flg else 8

    clip_i_final = clip_i + 16
    clip_seg_final = inputs[:, :, clip_i:clip_i_final, :, :]
    clip_img_final = inputs_resnet[:, clip_i + rand_ind, :, :, :]
    clip_img_final = process_sig_img(clip_img_final, clip_img_final.size())
    lab_seq_sub_final = seq_labels[:, :, clip_i:clip_i_final]

    # Plain tensors carry autograd state themselves; the Variable wrapper is not needed.
    clip_seg_final = clip_seg_final.to(device)
    clip_img_final = clip_img_final.to(device)

    num_samples = lab_seq_sub_final.shape[0]
    seqeunce_label = np.empty([num_samples])
    for c_ind in range(num_samples):
        frame_labels = lab_seq_sub_final[c_ind, :, :].squeeze().numpy()
        value, _ = Counter(frame_labels).most_common()[0]
        seqeunce_label[c_ind] = value - 1    # shift the majority label down by one

    seqeunce_label = torch.from_numpy(seqeunce_label).to(device).long()

    seqeunce_percent = 0        # no meaning for current study

    return clip_seg_final, clip_img_final, seqeunce_label, seqeunce_percent


def _write_result_rows(file_suffix, accuracy_rows, echo=False):
    """Stack the accuracy series and write the first four rows to a new text file.

    The file is created in the current working directory and is named
    ``<timeit.default_timer()><file_suffix>``. Each row is written as its NumPy
    string representation followed by a newline. With ``echo=True`` the stacked
    array is printed first.
    """
    txtfile_name = str(timeit.default_timer()) + file_suffix
    content = np.vstack(accuracy_rows)

    if echo:
        print(content)

    # "w" truncates an existing file, so no separate remove-then-append is needed.
    with open(txtfile_name, "w") as f:
        for row in content[:4]:
            f.write(str(row) + "\n")


def wrtieresults(testing_acc_long, testing_acc_activity, testing_acc_long_pred, testing_acc_activity_pred):
    """Write the four accuracy series to ``<timer>results.txt``."""
    _write_result_rows(
        'results.txt',
        (testing_acc_long, testing_acc_activity, testing_acc_long_pred, testing_acc_activity_pred))


def wrtieresults_train(testing_acc_long, testing_acc_activity, testing_acc_long_pred, testing_acc_activity_pred):
    """Write the four accuracy series to ``<timer>results_train.txt``."""
    _write_result_rows(
        'results_train.txt',
        (testing_acc_long, testing_acc_activity, testing_acc_long_pred, testing_acc_activity_pred))


def wrtieresults_top_pred(testing_acc_long, testing_acc_activity, testing_acc_long_pred, testing_acc_activity_pred):
    """Print the four accuracy series and write them to ``<timer>results_top_prediction.txt``."""
    _write_result_rows(
        'results_top_prediction.txt',
        (testing_acc_long, testing_acc_activity, testing_acc_long_pred, testing_acc_activity_pred),
        echo=True)


def wrtieresults_top_recog(testing_acc_long, testing_acc_activity, testing_acc_long_pred, testing_acc_activity_pred):
    """Print the four accuracy series and write them to ``<timer>results_top_recogniton.txt``."""
    _write_result_rows(
        'results_top_recogniton.txt',
        (testing_acc_long, testing_acc_activity, testing_acc_long_pred, testing_acc_activity_pred),
        echo=True)


def wrtie_confusion_results(training_intent_confuse, training_intent_lab_confuse, training_activity_confuse, training_activity_lab_confuse, txtfile_name):
    """Save intent and activity prediction/label pairs as two comma-separated files.

    Each pair is stacked vertically (first argument of the pair on top), rounded
    and written with ``np.savetxt``. The files are named
    ``<stop_time_init>_int<txtfile_name>`` and ``<stop_time_init>_act<txtfile_name>``.
    """
    tagged_pairs = (
        ('_int', training_intent_confuse, training_intent_lab_confuse),
        ('_act', training_activity_confuse, training_activity_lab_confuse),
    )
    # Stack both pairs before writing anything, so a bad input leaves no partial output.
    stacked = [(stop_time_init + tag + txtfile_name, np.vstack((first, second)))
               for tag, first, second in tagged_pairs]

    for out_path, table in stacked:
        np.savetxt(out_path, np.round(table), delimiter=',')


def train_model(dataset=dataset, save_dir=save_dir, num_classes=num_classes, num_activities=num_activtites,  lr=lr,
                num_epochs=nEpochs, save_epoch=snapshot, useTest=useTest, test_interval=nTestInterval, pred_horizon=1):
    # pred_horizon = 3
    print('current training pred_horizon', pred_horizon)
    """
        Args:
            num_classes (int): Number of classes in the data
            num_epochs (int, optional): Number of epochs to train for.
    """
    fusion_dim = 9 - pred_horizon

    if modelName == '2classes':
        encoder_model = res_3_2d_net()
        fusion_model = Fusion_net_2(num_classes=num_activities,pred_horizon=fusion_dim)
        decoder_model = LSTM_ANNO(num_classes=num_classes)
        predict_model = prediction_net(num_activity_classes=num_activities, num_intent_classes=num_classes, fusion_ind=True, pred_horizon=fusion_dim)
        
        if require_mtl:
            mt_loss = MultiTaskLossWrapper(task_num=num_tasks)
            print('loss parameters',[*mt_loss.parameters()])
        
    else:
        print('We only implemented C3D and R2Plus1D models.')
        raise NotImplementedError

    if resume_epoch == 0:
        print("Training {} from scratch...".format(modelName))
    else:
        checkpoint = torch.load(
            os.path.join(save_dir, 'models', saveName + '_epoch-' + str(resume_epoch) + '.pth'),
            map_location=lambda storage, loc: storage)  # Load all tensors onto the CPU
        
        print("Initializing weights from: {}...".format(
            os.path.join(save_dir, 'models', saveName + '_epoch-' + str(resume_epoch) + '.pth')))
        encoder_model.load_state_dict(checkpoint['encoder_state_dict'])
        decoder_model.load_state_dict(checkpoint['decoder_state_dict'])
        predict_model.load_state_dict(checkpoint['predict_state_dict'])
        fusion_model.load_state_dict(checkpoint['fusion_state_dict'])
        mt_loss.load_state_dict(checkpoint['mt_loss'])

    print('Total params: %.2fM' % (sum(p.numel() for p in encoder_model.parameters()) / 1000000.0))

    criterion = nn.CrossEntropyLoss()  # standard crossentropy loss for classification
    criterion1 = nn.MSELoss()

    encoder_model.to(device)
    fusion_model.to(device)
    mt_loss.to(device)
    decoder_model.to(device)
    predict_model.to(device)
    
    train_params = [
                    {'params': encoder_model.parameters(), 'lr': lr / 10},
                    {'params': mt_loss.parameters(), 'lr': lr},
                    {'params': Fusion_net_2.get_1x_lr_params(fusion_model), 'lr': lr / 10},
                    {'params': fusion_model.spat_weigh, 'lr': lr },
                    {'params': fusion_model.temp_weigh, 'lr': lr },
                    {'params': prediction_net.get_1x_lr_params(predict_model), 'lr': lr / 10},
                    {'params': predict_model.intent_weigh, 'lr': lr },
                    {'params': predict_model.act_weigh, 'lr': lr },
                    {'params': decoder_model.parameters(), 'lr': lr / 10}]
    print('train_params', train_params)

    optimizer = optim.Adam(train_params, lr=lr, betas=(0.9, 0.999), weight_decay=5e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=120,
                                          gamma=0.1)  # the scheduler divides the lr by 10 every 10 epochs
    if resume_epoch != 0:
      optimizer.load_state_dict(checkpoint['opt_dict'])
      for state in optimizer.state.values():
          for k, v in state.items():
              if torch.is_tensor(v):
                  state[k] = v.cuda()                                      

    log_dir = os.path.join(save_dir, 'models', datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname())
    print('tensorboard log_dir', log_dir)
    writer = SummaryWriter(log_dir=log_dir)

    print('Training model on {} dataset...'.format(dataset))

    bat_size = 3
    clp_len = 150

    mean = [0.43216, 0.394666, 0.37645]
    std = [0.22803, 0.22145, 0.216989]
    resize_size = 224
    crop_size = 112

    mean_res = [0.485, 0.456, 0.406]
    std_res = [0.229, 0.224, 0.225]

    trans1 = [
        ConvertBHWCtoBCHW(),
        transforms.ConvertImageDtype(torch.float32),
    ]
    trans1.extend([
        transforms.Normalize(mean=mean, std=std),
        ConvertBCHWtoCBHW()])

    trans2 = [
        ConvertBHWCtoBCHW(),
        transforms.ConvertImageDtype(torch.float32),
    ]
    trans2.extend([
        transforms.Normalize(mean=mean_res, std=std_res)])

    transform1_t = transforms.Compose(trans1)
    transform2_t = transforms.Compose(trans2)

    train_dataloader = DataLoader(CarDataset_multi_2(dataset=dataset, split='train', clip_len=clp_len,
                                                  transform1=transform1_t, transform2=None),
                                  batch_size=bat_size, shuffle=True, num_workers=4, pin_memory=True)

    val_dataloader = DataLoader(CarDataset_multi_2(dataset=dataset, split='val', clip_len=clp_len,
                                                transform1=transform1_t, transform2=None),
                                batch_size=bat_size, num_workers=4, pin_memory=True)

    test_dataloader = DataLoader(CarDataset_multi_2(dataset=dataset, split='test', clip_len=clp_len,
                                                 transform1=transform1_t, transform2=None),
                                 batch_size=bat_size, shuffle=True, num_workers=4, pin_memory=True)

    trainval_loaders = {'train': train_dataloader, 'val': val_dataloader}
    trainval_sizes = {x: len(trainval_loaders[x].dataset) for x in ['train', 'val']}
    test_size = len(test_dataloader.dataset)

    training_loss = []
    training_acc_long = []
    training_acc_activity = []
    training_acc_long_pred = []
    training_acc_activity_pred = []

    training_acc_top1 = []
    training_acc_top3 = []
    training_acc_top1_MR = []
    training_acc_top3_MR = []
    training_acc_top1_IP = []
    training_acc_top3_IP = []
    training_acc_top1_MP = []
    training_acc_top3_MP = []

    testing_loss = []
    testing_acc_long = []
    testing_acc_activity = []
    testing_acc_long_pred = []
    testing_acc_activity_pred = []
    testing_acc_top1 = []
    testing_acc_top3 = []
    testing_acc_top1_MR = []
    testing_acc_top3_MR = []
    testing_acc_top1_IP = []
    testing_acc_top3_IP = []
    testing_acc_top1_MP = []
    testing_acc_top3_MP = []

    max_acc = 0

    for epoch in range(resume_epoch, num_epochs):
        # each epoch has a training and validation step
        print('epoch {}'.format(epoch))
        for phase in ['train']:
            start_time = timeit.default_timer()
            print('start time', start_time)

            running_loss = 0.0
            running_corrects_long = 0.0
            running_corrects_mid = 0.0
            running_corrects_long_pred = 0.0
            running_corrects_mid_pred = 0.0

            running_corrects_top1 = 0.0
            running_corrects_top3 = 0.0

            running_corrects_top1_MR = 0.0
            running_corrects_top3_MR = 0.0

            running_corrects_top1_IP = 0.0
            running_corrects_top3_IP = 0.0

            running_corrects_top1_MP = 0.0
            running_corrects_top3_MP = 0.0

            # reset the running loss and corrects
            # set model to train() or eval() mode depending on whether it is trained
            # or being validated. Primarily affects layers such as BatchNorm or Dropout.
            if phase == 'train':
                # scheduler.step() is to be called once every epoch during training
                optimizer.step()
                scheduler.step()

                encoder_model.train()
                decoder_model.train()
                predict_model.train()
                mt_loss.train()
                fusion_model.train()

            else:
                encoder_model.eval()
                decoder_model.eval()
                predict_model.eval()
                mt_loss.eval()
                fusion_model.eval()


            for inputs, inputs_resnet, labels, seq_labels in tqdm(trainval_loaders[phase]):
                # move inputs and labels to the device the training is taking place on
                labels = labels.clone().detach()
                labels = labels.long()
                labels = Variable(labels, requires_grad=False).to(device)
                print('labels', labels.size(), labels)

                batch_size, C, frames, H, W = inputs.shape

                clip_feats = torch.Tensor([]).to(device)
                clip_activity_feats = torch.Tensor([]).to(device)
                tot_activity_pred = []
                tot_seq_lab = []
                clip_ind = 0
                init_flag = True
                pre1_sum = torch.Tensor([]).to(device)
                pre3_sum = torch.Tensor([]).to(device)

                loss_mid_tot = 0
                loss_cos = 0
                loss_mid_reg = 0
                loss_mid = 0
                ind = 0

                clip_len = len(np.arange(0, frames-(17+pred_horizon*16), 16))
                pred_ind = frames-17
                for clip_i in np.arange(0, frames-(17+pred_horizon*16), 16):
                    ind += 1
                    clip_seg, clip_img, seqeunce_label, seqeunce_percent = generate_last_prediction_data(clip_i, inputs, inputs_resnet, seq_labels)

                    clip_feats_temp, sig_feats_temp = encoder_model(clip_seg, clip_img)
                    clip_feats_int, activity_pred, activity_output, [spat_weigh, temp_weigh] = fusion_model(clip_feats_temp, sig_feats_temp)

                    loss_mid = criterion(activity_pred, seqeunce_label)
                    loss_mid_tot += loss_mid

                    prec1, prec3 = accuracy(activity_pred.data, seqeunce_label.data, topk=(1, 3))

                    pre1_sum = torch.cat([pre1_sum, prec1])
                    pre3_sum = torch.cat([pre3_sum, prec3])

                    if init_flag:
                        probs_mid_sub = nn.Softmax(dim=1)(activity_pred)
                        preds_mid_sub = torch.max(probs_mid_sub, 1)[1]
                        seq_labs_mid = seqeunce_label
                        init_flag = False
                    else:
                        probs_mid_sub = nn.Softmax(dim=1)(activity_pred)
                        preds_mid_sub_s = torch.max(probs_mid_sub, 1)[1]
                        preds_mid_sub = torch.cat((preds_mid_sub, preds_mid_sub_s))
                        seq_labs_mid = torch.cat((seq_labs_mid, seqeunce_label))

                    clip_feats = torch.cat((clip_feats, clip_feats_int), 1)
                    clip_activity_feats = torch.cat((clip_activity_feats, activity_output), 1)

                print('loss_mid_tot', loss_mid_tot, loss_mid_tot.size())

                prec1_MR = torch.mean(pre1_sum)
                prec3_MR = torch.mean(pre3_sum)

                clip_seg_final, clip_img_final, lab_seq_final, lab_percent_final = generate_last_prediction_data(pred_ind, inputs, inputs_resnet, seq_labels)
 
                pred_final_score, gru_output = decoder_model(clip_feats)
                predict_intent_score, predict_activity_score, [int_weigh, act_weigh] = predict_model(clip_activity_feats, gru_output)

                print('fusion module spaial:', spat_weigh, 'fusion module temporal:', temp_weigh)
                print('predict module intent:', int_weigh, 'predict module activity', act_weigh)

                probs_pred_long = nn.Softmax(dim=1)(predict_intent_score)
                preds_pred_long = torch.max(probs_pred_long, 1)[1]

                probs_pred_mid = nn.Softmax(dim=1)(predict_activity_score)
                preds_pred_mid = torch.max(probs_pred_mid, 1)[1]

                probs = nn.Softmax(dim=1)(pred_final_score)
                preds = torch.max(probs, 1)[1]

                preds_mid = preds_mid_sub
                seq_labs = seq_labs_mid

                prec1_LI, prec3_LI = accuracy(pred_final_score.data, labels.data, topk=(1, 3))
                prec1_LP, prec3_LP = accuracy(predict_intent_score.data, labels.data, topk=(1, 3))
                prec1_MP, prec3_MP = accuracy(predict_activity_score.data, lab_seq_final.data, topk=(1, 3))

                if require_mtl:
                    loss_long = criterion(pred_final_score, labels)
                    loss_tot = mt_loss(loss_mid_tot, loss_long, labels, predict_intent_score,
                                       predict_activity_score, lab_seq_final)
                    print('loss_tot', loss_tot, loss_tot.size())
                else:
                    loss_mid = criterion(activity_pred, seqeunce_label)
                    print('loss_mid', loss_mid)
                    loss_long = criterion(pred_final_score, labels)
                    print('loss_long', loss_long)
                    loss_tot = loss_long + (loss_mid)
                    print('loss_tot', loss_tot)

                if phase == 'train':
                    optimizer.zero_grad()
                    loss_tot.backward()
                    optimizer.step()

                running_loss += loss_tot.item() * inputs.size(0)
                running_corrects_long += torch.sum(preds == labels.data)
                running_corrects_mid += torch.sum(preds_mid == seq_labs.data)

                running_corrects_long_pred += torch.sum(preds_pred_long == labels.data)
                running_corrects_mid_pred += torch.sum(preds_pred_mid == lab_seq_final.data)

                running_corrects_top1 += torch.sum(prec1_LI)
                running_corrects_top3 += torch.sum(prec3_LI)

                running_corrects_top1_MR += torch.sum(prec1_MR)
                running_corrects_top3_MR += torch.sum(prec3_MR)

                running_corrects_top1_IP += torch.sum(prec1_LP)
                running_corrects_top3_IP += torch.sum(prec3_LP)

                running_corrects_top1_MP += torch.sum(prec1_MP)
                running_corrects_top3_MP += torch.sum(prec3_MP)

            train_epoch_loss = running_loss / trainval_sizes[phase]
            train_epoch_acc_long = running_corrects_long.double() / trainval_sizes[phase]
            train_epoch_acc_mid = running_corrects_mid.double() / (trainval_sizes[phase]*(clip_len))
            train_epoch_acc_long_pred = running_corrects_long_pred.double() / trainval_sizes[phase]
            train_epoch_acc_mid_pred = running_corrects_mid_pred.double() / trainval_sizes[phase]

            train_epoch_acc_top1 = running_corrects_top1.double() * bat_size / trainval_sizes[phase]
            train_epoch_acc_top3 = running_corrects_top3.double() * bat_size / trainval_sizes[phase]

            train_epoch_acc_top1_MR = running_corrects_top1_MR.double() * bat_size / trainval_sizes[phase]
            train_epoch_acc_top3_MR = running_corrects_top3_MR.double() * bat_size / trainval_sizes[phase]

            train_epoch_acc_top1_IP = running_corrects_top1_IP.double() * bat_size / trainval_sizes[phase]
            train_epoch_acc_top3_IP = running_corrects_top3_IP.double() * bat_size / trainval_sizes[phase]

            train_epoch_acc_top1_MP = running_corrects_top1_MP.double() * bat_size / trainval_sizes[phase]
            train_epoch_acc_top3_MP = running_corrects_top3_MP.double() * bat_size / trainval_sizes[phase]

            print('train_epoch_acc_top1', train_epoch_acc_top1,' train_epoch_acc_top3', train_epoch_acc_top3)

            train_epoch_acc_cpu_long = train_epoch_acc_long.data.cpu().numpy()
            train_epoch_acc_cpu_mid = train_epoch_acc_mid.data.cpu().numpy()
            train_epoch_acc_cpu_long_pred = train_epoch_acc_long_pred.cpu().numpy()
            train_epoch_acc_cpu_mid_pred = train_epoch_acc_mid_pred.cpu().numpy()

            train_epoch_acc_cpu_top1 = train_epoch_acc_top1.cpu().numpy()
            train_epoch_acc_cpu_top3 = train_epoch_acc_top3.cpu().numpy()

            train_epoch_acc_cpu_top1_MR = train_epoch_acc_top1_MR.cpu().numpy()
            train_epoch_acc_cpu_top3_MR = train_epoch_acc_top3_MR.cpu().numpy()

            train_epoch_acc_cpu_top1_IP = train_epoch_acc_top1_IP.cpu().numpy()
            train_epoch_acc_cpu_top3_IP = train_epoch_acc_top3_IP.cpu().numpy()

            train_epoch_acc_cpu_top1_MP = train_epoch_acc_top1_MP.cpu().numpy()
            train_epoch_acc_cpu_top3_MP = train_epoch_acc_top3_MP.cpu().numpy()

            training_loss = np.append(training_loss, train_epoch_loss)
            training_acc_long = np.append(training_acc_long, train_epoch_acc_cpu_long)
            training_acc_activity = np.append(training_acc_activity, train_epoch_acc_cpu_mid)
            training_acc_long_pred = np.append(training_acc_long_pred, train_epoch_acc_cpu_long_pred)
            training_acc_activity_pred = np.append(training_acc_activity_pred, train_epoch_acc_cpu_mid_pred)

            training_acc_top1 = np.append(training_acc_top1, train_epoch_acc_cpu_top1)
            training_acc_top3 = np.append(training_acc_top3, train_epoch_acc_cpu_top3)
            training_acc_top1_MR = np.append(training_acc_top1_MR, train_epoch_acc_cpu_top1_MR)
            training_acc_top3_MR = np.append(training_acc_top3_MR, train_epoch_acc_cpu_top3_MR)

            training_acc_top1_IP = np.append(training_acc_top1_IP, train_epoch_acc_cpu_top1_IP)
            training_acc_top3_IP = np.append(training_acc_top3_IP, train_epoch_acc_cpu_top3_IP)
            training_acc_top1_MP = np.append(training_acc_top1_MP, train_epoch_acc_cpu_top1_MP)
            training_acc_top3_MP = np.append(training_acc_top3_MP, train_epoch_acc_cpu_top3_MP)

            writer.add_scalar('data/train_loss_epoch', train_epoch_loss, epoch)
            writer.add_scalar('data/train_acc_long_epoch', train_epoch_acc_long, epoch)
            writer.add_scalar('data/train_acc_mid_epoch', train_epoch_acc_mid, epoch)
            writer.add_scalar('data/train_acc_long_pred_epoch', train_epoch_acc_long_pred, epoch)
            writer.add_scalar('data/train_acc_mid_pred_epoch', train_epoch_acc_mid_pred, epoch)

            writer.add_scalar('data/train_top1_epoch', train_epoch_acc_cpu_top1, epoch)
            writer.add_scalar('data/train_top3_epoch', train_epoch_acc_cpu_top3, epoch)
            writer.add_scalar('data/train_epoch_acc_cpu_top1_MR', train_epoch_acc_cpu_top1_MR, epoch)
            writer.add_scalar('data/train_epoch_acc_cpu_top3_MR', train_epoch_acc_cpu_top3_MR, epoch)
            writer.add_scalar('data/train_epoch_acc_cpu_top1_IP', train_epoch_acc_cpu_top1_IP, epoch)
            writer.add_scalar('data/train_epoch_acc_cpu_top3_IP', train_epoch_acc_cpu_top3_IP, epoch)
            writer.add_scalar('data/train_epoch_acc_cpu_top1_MP', train_epoch_acc_cpu_top1_MP, epoch)
            writer.add_scalar('data/train_epoch_acc_cpu_top3_MP', train_epoch_acc_cpu_top3_MP, epoch)

            writer.flush()

            print("[{}] Epoch: {}/{} Loss: {} Intent Acc: {} Activity Acc: {} Pred Intent Acc: {} Pred Activity Acc: {}"
                  .format(phase, epoch + 1, nEpochs, train_epoch_loss, train_epoch_acc_long, train_epoch_acc_mid,
                          train_epoch_acc_long_pred, train_epoch_acc_mid_pred))

            stop_time = timeit.default_timer()
            print("Execution time: " + str(stop_time - start_time) + "\n")

            if epoch % save_epoch == (save_epoch - 1):
                torch.save({
                    'epoch': epoch + 1,
                    'encoder_state_dict': encoder_model.state_dict(),
                    'decoder_state_dict': decoder_model.state_dict(),
                    'fusion_state_dict': fusion_model.state_dict(),
                    'predict_state_dict': predict_model.state_dict(),
                    'mt_loss': mt_loss.state_dict(),
                    'opt_dict': optimizer.state_dict(),
                }, os.path.join(save_dir, 'models', saveName + '_epoch-' + str(epoch) + '.pth'))
                print("Save model at {}\n".format(os.path.join(save_dir, 'models', saveName + '_epoch-' + str(epoch) + '.pth')))

        if useTest and epoch % test_interval == (test_interval - 1):
            encoder_model.eval()
            mt_loss.eval()
            fusion_model.eval()
            decoder_model.eval()
            predict_model.eval()

            start_time = timeit.default_timer()

            running_loss = 0.0
            running_corrects_long = 0.0
            running_corrects_mid = 0.0
            running_corrects_long_pred = 0.0
            running_corrects_mid_pred = 0.0
            
            running_corrects_top1 = 0.0
            running_corrects_top3 = 0.0
            running_corrects_top1_MR = 0.0
            running_corrects_top3_MR = 0.0
            running_corrects_top1_IP = 0.0
            running_corrects_top3_IP = 0.0
            running_corrects_top1_MP = 0.0
            running_corrects_top3_MP = 0.0

            running_recog_intent = []
            running_recog_label = []

            running_recog_activity = []
            running_recog_seqlab = []

            running_pred_intent = []
            running_pred_label = []

            running_pred_activity = []
            running_pred_seqlab = []

            print('---------Testing-------------------')
            for inputs, inputs_resnet, labels, seq_labels in tqdm(test_dataloader):
                inputs = inputs.to(device)
                labels = labels.to(device, dtype=torch.long)
                with torch.no_grad():
                    batch_size, C, frames, H, W = inputs.shape
                    clip_feats = torch.Tensor([]).to(device)
                    clip_activity_feats = torch.Tensor([]).to(device)
                    
                    pre1_sum = torch.Tensor([]).to(device)
                    pre3_sum = torch.Tensor([]).to(device)
                    
                    loss_mid_tot = 0
                    loss_mid_reg = 0
                    loss_cos = 0

                    clip_len = len(np.arange(0, frames - (17 + pred_horizon * 16), 16))
                    pred_ind = frames - 17

                    init_flag_test = True
                    for clip_i in np.arange(0, frames - (17+pred_horizon*16), 16):

                        clip_seg, clip_img, seqeunce_label, seqeunce_percent = generate_last_prediction_data(clip_i, inputs,
                                                                                           inputs_resnet, seq_labels)

                        clip_feats_temp, sig_feats_temp = encoder_model(clip_seg, clip_img)
                       
                        clip_feats_int, activity_pred, activity_output, [spat_weigh, temp_weigh] = fusion_model(clip_feats_temp, sig_feats_temp)

                        loss_mid = criterion(activity_pred, seqeunce_label)
                        prec1, prec3 = accuracy(activity_pred.data, seqeunce_label.data, topk=(1, 3))

                        loss_mid_tot += loss_mid

                        pre1_sum = torch.cat([pre1_sum, prec1])
                        pre3_sum = torch.cat([pre3_sum, prec3])

                        if init_flag_test:
                            probs_mid_sub = nn.Softmax(dim=1)(activity_pred)
                            preds_mid_sub = torch.max(probs_mid_sub, 1)[1]
                            seq_labs_mid = seqeunce_label
                            init_flag_test = False
                        else:
                            probs_mid_sub = nn.Softmax(dim=1)(activity_pred)
                            preds_mid_sub_s = torch.max(probs_mid_sub, 1)[1]
                            preds_mid_sub = torch.cat((preds_mid_sub, preds_mid_sub_s))
                            seq_labs_mid = torch.cat((seq_labs_mid, seqeunce_label))

                        clip_feats = torch.cat((clip_feats, clip_feats_int), 1)
                        clip_activity_feats = torch.cat((clip_activity_feats, activity_output), 1)

                prec1_MR = torch.mean(pre1_sum)
                prec3_MR = torch.mean(pre3_sum)

                clip_seg_final, clip_img_final, lab_seq_final, lab_percent_final = generate_last_prediction_data(pred_ind,inputs,inputs_resnet,seq_labels)

                preds_mid = preds_mid_sub
                seq_labs = seq_labs_mid

                pred_final_score, gru_output = decoder_model(clip_feats)
                predict_intent_score, predict_activity_score, [int_weigh, act_weigh] = predict_model(clip_activity_feats, gru_output)

                probs_pred_long = nn.Softmax(dim=1)(predict_intent_score)
                preds_pred_long = torch.max(probs_pred_long, 1)[1]

                probs_pred_mid = nn.Softmax(dim=1)(predict_activity_score)
                preds_pred_mid = torch.max(probs_pred_mid, 1)[1]

                probs = nn.Softmax(dim=1)(pred_final_score)
                preds = torch.max(probs, 1)[1]

                prec1_LI, prec3_LI = accuracy(pred_final_score.data, labels.data, topk=(1, 3))
                prec1_LP, prec3_LP = accuracy(predict_intent_score.data, labels.data, topk=(1, 3))
                prec1_MP, prec3_MP = accuracy(predict_activity_score.data, lab_seq_final.data, topk=(1, 3))
                print('prec1_MR', prec1_MR,'prec3_MR',prec3_MR,'prec1_LI',prec1_LI,'prec3_LI',prec3_LI)

                # Recognize intent
                recog_tmp_intent = preds.clone().detach().cpu().numpy()
                recog_lab_intent = labels.clone().detach().cpu().numpy()
                # Recognize activity
                recog_tmp_activity = preds_mid.clone().detach().cpu().numpy()
                lab_seq_act_recog = seq_labs.clone().detach().cpu().numpy()
                # predict intent
                preds_tmp_intent = preds_pred_long.clone().detach().cpu().numpy()
                lab_tmp_intent = recog_lab_intent
                # predict activity
                preds_tmp_act = preds_pred_mid.clone().detach().cpu().numpy()
                lab_seq_act_final = lab_seq_final.data.clone().detach().cpu().numpy()

                if require_mtl:
                    loss_long = criterion(pred_final_score, labels)
                    loss_tot = mt_loss(loss_mid_tot, loss_long, labels, predict_intent_score, predict_activity_score, lab_seq_final)
                else:
                    loss_mid = criterion(activity_pred, seqeunce_label)
                    print('loss_mid', loss_mid)
                    loss_long = criterion(pred_final_score, labels)
                    print('loss_long', loss_long)
                    loss_tot = loss_long + (loss_mid)
                    print('loss_tot', loss_tot)

                running_loss += loss_tot.item() * inputs.size(0)
                running_corrects_long += torch.sum(preds == labels.data)
                running_corrects_mid += torch.sum(preds_mid == seq_labs.data)

                running_corrects_long_pred += torch.sum(preds_pred_long == labels.data)
                running_corrects_mid_pred += torch.sum(preds_pred_mid == lab_seq_final.data)

                print('prec1_LI test', prec1_LI, 'prec3_LI test', prec3_LI)
                print('running_corrects_top1',running_corrects_top1,'running_corrects_top3',running_corrects_top3)

                running_corrects_top1 += torch.sum(prec1_LI)
                running_corrects_top3 += torch.sum(prec3_LI)

                running_corrects_top1_MR += torch.sum(prec1_MR)
                running_corrects_top3_MR += torch.sum(prec3_MR)

                running_corrects_top1_IP += torch.sum(prec1_LP)
                running_corrects_top3_IP += torch.sum(prec3_LP)

                running_corrects_top1_MP += torch.sum(prec1_MP)
                running_corrects_top3_MP += torch.sum(prec3_MP)

                running_recog_intent = np.append(running_recog_intent, recog_tmp_intent)
                running_recog_label = np.append(running_recog_label, recog_lab_intent)

                running_recog_activity = np.append(running_recog_activity, recog_tmp_activity)
                running_recog_seqlab = np.append(running_recog_seqlab, lab_seq_act_recog)

                running_pred_intent = np.append(running_pred_intent, preds_tmp_intent)
                running_pred_label = np.append(running_pred_label, lab_tmp_intent)

                running_pred_activity = np.append(running_pred_activity, preds_tmp_act)
                running_pred_seqlab = np.append(running_pred_seqlab, lab_seq_act_final)

            epoch_loss = running_loss / test_size
            epoch_acc_long = running_corrects_long.double() / test_size
            epoch_acc_mid = running_corrects_mid.double() / (test_size*(clip_len))
            epoch_acc_long_pred = running_corrects_long_pred.double() / test_size
            epoch_acc_mid_pred = running_corrects_mid_pred.double() / test_size

            print('running_corrects_top1', running_corrects_top1, 'running_corrects_top3', running_corrects_top3)

            epoch_acc_top1 = running_corrects_top1.double() * bat_size / test_size
            epoch_acc_top3 = running_corrects_top3.double() * bat_size / test_size

            epoch_acc_top1_MR = running_corrects_top1_MR.double() * bat_size / test_size
            epoch_acc_top3_MR = running_corrects_top3_MR.double() * bat_size / test_size

            epoch_acc_top1_IP = running_corrects_top1_IP.double() * bat_size / test_size
            epoch_acc_top3_IP = running_corrects_top3_IP.double() * bat_size / test_size

            epoch_acc_top1_MP = running_corrects_top1_MP.double() * bat_size / test_size
            epoch_acc_top3_MP = running_corrects_top3_MP.double() * bat_size / test_size

            print('epoch_acc_top1', epoch_acc_top1, ' epoch_acc_top3', epoch_acc_top3)
            print('epoch_acc_top1_MR', epoch_acc_top1_MR, ' epoch_acc_top3_MR', epoch_acc_top3_MR)

            epoch_acc_cpu_long = epoch_acc_long.data.cpu().numpy()
            epoch_acc_cpu_mid = epoch_acc_mid.data.cpu().numpy()
            epoch_acc_cpu_long_pred = epoch_acc_long_pred.data.cpu().numpy()
            epoch_acc_cpu_mid_pred = epoch_acc_mid_pred.data.cpu().numpy()

            epoch_acc_cpu_top1 = epoch_acc_top1.data.cpu().numpy()
            epoch_acc_cpu_top3 = epoch_acc_top3.data.cpu().numpy()
            epoch_acc_cpu_top1_MR = epoch_acc_top1_MR.data.cpu().numpy()
            epoch_acc_cpu_top3_MR = epoch_acc_top3_MR.data.cpu().numpy()

            epoch_acc_cpu_top1_IP = epoch_acc_top1_IP.data.cpu().numpy()
            epoch_acc_cpu_top3_IP = epoch_acc_top3_IP.data.cpu().numpy()
            epoch_acc_cpu_top1_MP = epoch_acc_top1_MP.data.cpu().numpy()
            epoch_acc_cpu_top3_MP = epoch_acc_top3_MP.data.cpu().numpy()

            testing_loss = np.append(testing_loss, epoch_loss)
            testing_acc_long = np.append(testing_acc_long, epoch_acc_cpu_long)
            testing_acc_activity = np.append(testing_acc_activity, epoch_acc_cpu_mid)
            testing_acc_long_pred = np.append(testing_acc_long_pred, epoch_acc_cpu_long_pred)
            testing_acc_activity_pred = np.append(testing_acc_activity_pred, epoch_acc_cpu_mid_pred)

            testing_acc_top1 = np.append(testing_acc_top1, epoch_acc_cpu_top1)
            testing_acc_top3 = np.append(testing_acc_top3, epoch_acc_cpu_top3)
            testing_acc_top1_MR = np.append(testing_acc_top1_MR, epoch_acc_cpu_top1_MR)
            testing_acc_top3_MR = np.append(testing_acc_top3_MR, epoch_acc_cpu_top3_MR)

            testing_acc_top1_IP = np.append(testing_acc_top1_IP, epoch_acc_cpu_top1_IP)
            testing_acc_top3_IP = np.append(testing_acc_top3_IP, epoch_acc_cpu_top3_IP)
            testing_acc_top1_MP = np.append(testing_acc_top1_MP, epoch_acc_cpu_top1_MP)
            testing_acc_top3_MP = np.append(testing_acc_top3_MP, epoch_acc_cpu_top3_MP)

            print('epoch_acc_cpu_long:',epoch_acc_cpu_long, 'max_acc:', max_acc)
            if epoch_acc_cpu_long > max_acc:
                max_acc = epoch_acc_cpu_long
                wrtie_confusion_results(running_recog_intent, running_recog_label, running_recog_activity,
                                        running_recog_seqlab, '_recog_max_confusion.txt')
                wrtie_confusion_results(running_pred_intent, running_pred_label, running_pred_activity,
                                        running_pred_seqlab, '_pred_max_confusion.txt')

                print("Save model at {}\n".format(
                    os.path.join(save_dir, 'models', saveName + '_epoch-' + str(epoch) + '.pth')))

            writer.add_scalar('data/test_max_acc_epoch', max_acc, epoch)
            writer.add_scalar('data/test_loss_epoch', epoch_loss, epoch)
            writer.add_scalar('data/test_acc_long_epoch', epoch_acc_long, epoch)
            writer.add_scalar('data/test_acc_mid_epoch', epoch_acc_mid, epoch)
            writer.add_scalar('data/test_acc_long_pred_epoch', epoch_acc_long_pred, epoch)
            writer.add_scalar('data/test_acc_mid_pred_epoch', epoch_acc_mid_pred, epoch)

            writer.add_scalar('data/test_top1_epoch', epoch_acc_cpu_top1, epoch)
            writer.add_scalar('data/test_top3_epoch', epoch_acc_cpu_top3, epoch)
            writer.add_scalar('data/epoch_acc_cpu_top1_MR', epoch_acc_cpu_top1_MR, epoch)
            writer.add_scalar('data/epoch_acc_cpu_top3_MR', epoch_acc_cpu_top3_MR, epoch)
            writer.add_scalar('data/epoch_acc_cpu_top1_IP', epoch_acc_cpu_top1_IP, epoch)
            writer.add_scalar('data/epoch_acc_cpu_top3_IP', epoch_acc_cpu_top3_IP, epoch)
            writer.add_scalar('data/epoch_acc_cpu_top1_MP', epoch_acc_cpu_top1_MP, epoch)
            writer.add_scalar('data/epoch_acc_cpu_top3_MP', epoch_acc_cpu_top3_MP, epoch)

            writer.flush()

            print("[test] Epoch: {}/{} Loss: {} Intent Acc: {} Activity Acc: {} Pred Intent Acc: {} Pred Activity Acc: {}"
                  .format(epoch + 1, nEpochs, epoch_loss, epoch_acc_long, epoch_acc_mid, epoch_acc_long_pred, epoch_acc_mid_pred))

            stop_time = timeit.default_timer()
            print("Execution time: " + str(stop_time - start_time) + "\n")

    print('testing', testing_acc_long, testing_acc_activity, testing_acc_long_pred, testing_acc_activity_pred)
    print('plt plot')

    wrtieresults_train(training_acc_long, training_acc_activity, training_acc_long_pred, training_acc_activity_pred)
    wrtieresults(testing_acc_long, testing_acc_activity, testing_acc_long_pred, testing_acc_activity_pred)

    wrtieresults_top_recog(testing_acc_top1, testing_acc_top3, testing_acc_top1_MR, testing_acc_top3_MR)
    wrtieresults_top_pred(testing_acc_top1_IP, testing_acc_top3_IP, testing_acc_top1_MP, testing_acc_top3_MP)

    writer.close()


# Horizon values to sweep over; np.arange(1, 2) contains the single value 1.
# pred_horizon = 0
pred_horizon = np.arange(1, 2)
print(pred_horizon)

# Run one full train_model call per horizon value, forwarding the
# notebook-level settings unchanged and the current horizon as pred_horizon.
for pre_h in pred_horizon:
    print('pre_h', pre_h)
    train_model(
        dataset=dataset,
        save_dir=save_dir,
        num_classes=num_classes,
        num_activities=num_activtites,
        lr=lr,
        num_epochs=nEpochs,
        save_epoch=snapshot,
        useTest=useTest,
        test_interval=nTestInterval,
        pred_horizon=pre_h,
    )