In [None]:
import os
import numpy as np
import torch
from torch import nn,einsum
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
import torch.utils.data as data
import torchvision
from torch.autograd import Variable
import matplotlib.pyplot as plt
from functions import *
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
from sklearn.metrics import accuracy_score
import pickle

In [None]:

# ---- Paths ----
data_path = "/home/hanpeiheng/dataset/yawdd_temp/"
action_name_path = './data.pkl'
save_model_path = "../vivit_ckpt/"

# ---- CNN / image settings ----
CNN_fc_hidden1 = 256
CNN_fc_hidden2 = 128
CNN_embed_dim = 256
img_x = 80
img_y = 80
img_size = 80
ptc_size = 20
dropout_p = 0.3

# ---- RNN settings ----
RNN_hidden_layers = 3    # number of hidden layers
RNN_hidden_nodes = 128   # nodes per hidden layer
RNN_FC_dim = 128         # width of the fully connected layer

# ---- Training settings ----
k = 3                       # number of target categories
epochs = 100                # training epochs
batch_size = 31             # samples per batch
learning_rate = 1e-4        # optimizer learning rate
log_interval = 10           # print training info every `log_interval` batches
lam = 1e-4                  # weight of the L1 penalty added to the loss
weight_decay_global = 1e-6  # weight decay passed to the optimizer
step_size = 10              # StepLR: epochs between learning-rate decays
gamma = 0.9                 # StepLR: multiplicative decay factor

# ---- Frame sampling: every `skip_frame`-th frame in [begin_frame, end_frame) ----
begin_frame = 11
end_frame = 91
skip_frame = 2

In [None]:
def train(log_interval, model, device, train_loader, optimizer, epoch):
    """Train `model` for one epoch.

    The loss of each batch is the cross-entropy plus an L1 penalty over all
    trainable parameters, weighted by the module-level constant `lam`.

    Args:
        log_interval: print progress every `log_interval` batches.
        model: network to train; called as `model(X)`.
        device: device the batches are moved to.
        train_loader: iterable yielding `(X, y)` batches; must expose
            `.dataset` and `len()` for the progress message.
        optimizer: optimizer stepped once per batch.
        epoch: epoch index, used only in the progress message.

    Returns:
        Tuple `(mean_loss, mean_accuracy)`, each the unweighted mean of the
        per-batch values.

    Raises:
        ValueError: if `train_loader` yields no batches.
    """
    # Put the model in training mode. The original only had a comment saying
    # so; without the call the model keeps whatever mode it was last left in.
    model.train()

    losses = []   # per-batch loss values
    scores = []   # per-batch accuracies
    N_count = 0   # samples seen so far in this epoch
    for batch_idx, (X, y) in enumerate(train_loader):
        X, y = X.to(device), y.to(device).view(-1)
        N_count += X.size(0)

        optimizer.zero_grad()
        output = model(X)

        # L1 penalty over every trainable parameter.
        re_loss = sum(param.abs().sum()
                      for param in model.parameters() if param.requires_grad)
        loss = F.cross_entropy(output, y) + re_loss * lam
        losses.append(loss.item())

        # Accuracy is computed directly on the 1-D tensors, so a batch with a
        # single sample works too (squeeze() would turn it into a 0-d array).
        y_pred = output.argmax(dim=1)
        step_score = (y_pred == y).sum().item() / y.numel()
        scores.append(step_score)

        loss.backward()
        optimizer.step()

        if (batch_idx + 1) % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}, Accu: {:.2f}%'.format(
                epoch , N_count, len(train_loader.dataset), 100. * (batch_idx + 1) / len(train_loader), loss.item(), 100 * step_score))

    if not losses:
        raise ValueError("train_loader produced no batches")

    return sum(losses)/len(losses), sum(scores)/len(scores)


def validation(model, device, optimizer, test_loader,epoch):
    """Evaluate `model` on `test_loader`, then checkpoint model and optimizer.

    The reported loss is the mean cross-entropy per sample plus the L1 penalty
    over all trainable parameters weighted by the module-level constant `lam`.
    The penalty is added exactly once, which matches how the training loss is
    formed (mean cross-entropy + penalty). The original added it once per
    batch and then divided by the number of samples.

    Args:
        model: network to evaluate; called as `model(X)`.
        device: device the batches are moved to.
        optimizer: optimizer whose state is saved next to the model weights.
        test_loader: iterable yielding `(X, y)` batches.
        epoch: epoch index, used in the checkpoint file names and messages.

    Returns:
        Tuple `(test_loss, test_score)`: the loss described above and the
        fraction of correctly classified samples.

    Raises:
        ValueError: if `test_loader` yields no samples.
    """
    # Evaluate in eval mode, then restore whatever mode the caller had so the
    # next training epoch is unaffected.
    was_training = model.training
    model.eval()

    ce_sum = 0.0        # summed cross-entropy over all samples
    all_y = []          # per-batch target tensors
    all_y_pred = []     # per-batch predicted-class tensors
    try:
        with torch.no_grad():
            # Parameters do not change during evaluation, so the L1 penalty
            # is computed once rather than once per batch.
            re_loss = sum(param.abs().sum()
                          for param in model.parameters() if param.requires_grad)
            reg_term = float(re_loss * lam)

            for X, y in test_loader:
                X, y = X.to(device), y.to(device).view(-1)
                output = model(X)
                ce_sum += F.cross_entropy(output, y, reduction='sum').item()
                all_y.append(y)
                all_y_pred.append(output.argmax(dim=1))
    finally:
        model.train(was_training)

    n_samples = sum(batch_y.numel() for batch_y in all_y)
    if n_samples == 0:
        raise ValueError("test_loader produced no samples")

    test_loss = ce_sum / n_samples + reg_term

    # Concatenate into 1-D tensors; no squeeze(), so one sample is fine too.
    all_y = torch.cat(all_y, dim=0)
    all_y_pred = torch.cat(all_y_pred, dim=0)
    test_score = (all_y_pred == all_y).sum().item() / n_samples

    print('Test set ({:d} samples): Average loss: {:.4f}, Accuracy: {:.2f}%\n'.format(len(all_y), test_loss, 100* test_score))

    # torch.save does not create missing parent directories.
    os.makedirs(save_model_path, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(save_model_path, 'vivit_epoch{}.pth'.format(epoch)))  # save model weights
    torch.save(optimizer.state_dict(), os.path.join(save_model_path, 'optimizer_epoch{}.pth'.format(epoch)))      # save optimizer
    print("Epoch {} model saved!".format(epoch))

    return test_loss, test_score

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")   # use CPU or GPU

# Keyword arguments shared by the DataLoaders. Batch size and shuffling apply
# on every device; previously a CPU-only run got an empty dict and therefore
# the DataLoader defaults (batch size 1, no shuffling).
params = {'batch_size': batch_size, 'shuffle': True}
if use_cuda:
    # Worker processes and pinned host memory are only requested on GPU runs.
    params.update({'num_workers': 6, 'pin_memory': True})

In [None]:
# Load the pickled collection of action names.
# NOTE(review): pickle.load can execute arbitrary code from the file; only
# point action_name_path at a file you trust.
with open(action_name_path, 'rb') as f:
    action_names = pickle.load(f)

# Fit the encoder that maps each action name to an integer id
# (fit() returns the encoder itself).
le = LabelEncoder().fit(action_names)

list(le.classes_)

# Integer ids as a single column, the layout the one-hot encoder is fitted on.
action_category = le.transform(action_names).reshape(-1, 1)

enc = OneHotEncoder()
enc.fit(action_category)




In [None]:
actions = []
# os.listdir returns entries in arbitrary order; sorting makes the list, and
# therefore the seeded train/test split built from it, the same on every run.
fnames = sorted(os.listdir(data_path))

all_names = []
for f in fnames:
    # The label is the part of the name after the first 's-', 'd-' or 'e-'
    # marker; the markers are tried in that order and the first hit wins.
    loc1 = -1
    for marker in ('s-', 'd-', 'e-'):
        loc1 = f.find(marker)
        if loc1 != -1:
            break
    if loc1 == -1:
        # Without a marker the slice below would silently yield f[1:].
        raise ValueError(
            "cannot extract an action label from {!r}: "
            "none of 's-', 'd-', 'e-' found".format(f))

    actions.append(f[(loc1 + 2):])
    all_names.append(f)

all_X_list = all_names
all_y_list = labels2cat(le, actions)    # all video labels

# Hold out 20% of the samples for validation (random_state fixed at 42).
train_list, test_list, train_label, test_label = train_test_split(
    all_X_list,
    all_y_list,
    test_size=0.2,
    random_state=42,
)

# Per-frame preprocessing: convert to a tensor, then normalise each channel.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.229, 0.224, 0.225]),
])

# Indices of the frames sampled from every video.
selected_frames = list(range(begin_frame, end_frame, skip_frame))

# Dataset_CRNN comes from functions.py.
train_set = Dataset_CRNN(data_path, train_list, train_label, selected_frames, transform=transform)
valid_set = Dataset_CRNN(data_path, test_list, test_label, selected_frames, transform=transform)

# Both loaders share the keyword arguments collected in `params`.
train_loader = data.DataLoader(train_set, **params)
valid_loader = data.DataLoader(valid_set, **params)

# Build the backbone and move it to the selected device.
# Earlier alternative, kept for reference:
# vivit = ViViT(image_size=img_size, patch_size=ptc_size, num_classes=3, num_frames=30,dropout=dropout_p).to(device)
vivit = ViViTBackbone(device=device).to(device)

# With more than one visible GPU, wrap the model so it runs on GPUs 0 and 1.
if 1 < torch.cuda.device_count():
    vivit = nn.DataParallel(vivit, device_ids=[0, 1])

# Adam over every model parameter, plus a step-wise learning-rate schedule.
vivit_params = list(vivit.parameters())
optimizer = torch.optim.Adam(
    vivit_params,
    lr=learning_rate,
    weight_decay=weight_decay_global,
)
StepLR = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=step_size,
    gamma=gamma,
)


In [None]:
# Per-epoch history of the training run.
epoch_train_losses = []
epoch_train_scores = []
epoch_test_losses = []
epoch_test_scores = []

# start training
for epoch in range(epochs):
    # One pass over the training set, then evaluation (which also checkpoints).
    train_losses, train_scores = train(log_interval, vivit, device, train_loader, optimizer, epoch)
    epoch_test_loss, epoch_test_score = validation(vivit, device, optimizer, valid_loader,epoch)

    # Advance the learning-rate schedule once per epoch. The scheduler was
    # created but never stepped, so the learning rate never decayed.
    StepLR.step()

    # Record this epoch's results.
    epoch_train_losses.append(train_losses)
    epoch_train_scores.append(train_scores)
    epoch_test_losses.append(epoch_test_loss)
    epoch_test_scores.append(epoch_test_score)

    # Rewrite the history files every epoch so an interrupted run keeps them.
    A = np.array(epoch_train_losses)
    B = np.array(epoch_train_scores)
    C = np.array(epoch_test_losses)
    D = np.array(epoch_test_scores)
    np.save('./vivit_epoch_training_losses.npy', A)
    np.save('./vivit_epoch_training_scores.npy', B)
    np.save('./vivit_epoch_test_loss.npy', C)
    np.save('./vivit_epoch_test_score.npy', D)

# Plot the per-epoch history: loss on the left, accuracy on the right.
# A, B, C and D are the 1-D per-epoch arrays built in the training loop.
fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))
epoch_axis = np.arange(1, epochs + 1)   # epoch numbers 1..epochs for the x-axis

# Left panel: training vs. test loss.
ax_loss.plot(epoch_axis, A)   # train loss per epoch
ax_loss.plot(epoch_axis, C)   # test loss per epoch
ax_loss.set_title("model loss")
ax_loss.set_xlabel('epochs')
ax_loss.set_ylabel('loss')
ax_loss.legend(['train', 'test'], loc="upper left")

# Right panel: training vs. test accuracy.
ax_acc.plot(epoch_axis, B)    # train accuracy per epoch
ax_acc.plot(epoch_axis, D)    # test accuracy per epoch
ax_acc.set_title("training scores")
ax_acc.set_xlabel('epochs')
ax_acc.set_ylabel('accuracy')
ax_acc.legend(['train', 'test'], loc="upper left")

# Write the figure to disk, then display it.
title = "./fig_yawdd_vivit.png"
fig.savefig(title, dpi=600)
# plt.close(fig)
plt.show()

In [None]:
import os

# Print the name of every sub-directory of `frames_root` that holds fewer
# than `min_entries` entries.
# NOTE(review): this root differs from `data_path` used for training - confirm
# which location is meant to be checked.
frames_root = "/hy-tmp/yawdd_temp/"
min_entries = 100

dirs = os.listdir(frames_root)
for dir_name in dirs:
    dir_path = os.path.join(frames_root, dir_name)
    # Skip plain files; os.listdir on one raises NotADirectoryError.
    if not os.path.isdir(dir_path):
        continue
    lst = os.listdir(dir_path)
    if len(lst) < min_entries:
        print(dir_name)