<a href="https://colab.research.google.com/github/CanZheng0331/Sensing_aided_Communications/blob/main/Computer_Vision_Aided_Beam_Tracking.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This is a Python code package related to the following article: S. Jiang and A. Alkhateeb, "**Computer Vision Aided Beam Tracking** in A Real-World Millimeter Wave Deployment," in IEEE Globecom Workshops, 2022.

The original open sourced code is in: https://github.com/acyiobs/vision_beam_tracking/tree/main.

Most of the annotations are written in Chinese.

In [1]:
## 下载 场景8 数据集、视觉数据标注（YOLO V4 识别的 2D 边界框的坐标和属性）并解压

!wget -q -O scenario8.zip "https://www.dropbox.com/scl/fi/6yhd4xddx7d2zk4mcllq6/scenario8.zip?rlkey=w472s91e9tvv5p78yszkxmww3&e=1&dl=0"
!wget -q -O camera_data_bbox.zip "https://www.dropbox.com/scl/fi/n1cqbxvpzxl9j4zhhgs3q/camera_data_bbox.zip?dl=0&e=1&rlkey=cibk7natbsm2axb8gzrvz12rl"
!unzip scenario8.zip
!unzip camera_data_bbox.zip -d "/content/DEV[95%]/unit1"

[1;30;43m流式输出内容被截断，只能显示最后 5000 行内容。[0m
  inflating: DEV[95%]/unit2/GPS_data_calibrated/gps_location_4099.txt  
  inflating: DEV[95%]/unit2/GPS_data_calibrated/gps_location_41.txt  
  inflating: DEV[95%]/unit2/GPS_data_calibrated/gps_location_410.txt  
  inflating: DEV[95%]/unit2/GPS_data_calibrated/gps_location_4100.txt  
  inflating: DEV[95%]/unit2/GPS_data_calibrated/gps_location_4101.txt  
  inflating: DEV[95%]/unit2/GPS_data_calibrated/gps_location_4102.txt  
  inflating: DEV[95%]/unit2/GPS_data_calibrated/gps_location_4103.txt  
  inflating: DEV[95%]/unit2/GPS_data_calibrated/gps_location_4104.txt  
  inflating: DEV[95%]/unit2/GPS_data_calibrated/gps_location_4105.txt  
  inflating: DEV[95%]/unit2/GPS_data_calibrated/gps_location_4106.txt  
  inflating: DEV[95%]/unit2/GPS_data_calibrated/gps_location_4107.txt  
  inflating: DEV[95%]/unit2/GPS_data_calibrated/gps_location_4108.txt  
  inflating: DEV[95%]/unit2/GPS_data_calibrated/gps_location_4109.txt  
  inflating: DEV[95%]/unit

In [None]:
!pip install pytorch_model_summary

In [2]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import os
import copy
import torch
from torch.utils.data import Dataset
import ast
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from pytorch_model_summary import summary
import sys
import datetime
from scipy.io import savemat
from collections import OrderedDict

In [3]:
## 序列生成
## Sequence Generator

# 阻塞预测函数
def blockage_prediction(csv_frame, data_list_y):
    data_label = np.empty((0, 1), dtype=int)
    for i in np.arange(len(data_list_y)):
        blockage = np.zeros(data_list_y[0])
        for j in range(len(data_list_y[i])):
            blockage[j] = csv_frame[csv_frame.index == data_list_y[i][j]]['blockage'].item()
        blockage = np.sum(blockage) > 0
        data_label = np.append(data_label, blockage)
    return data_label

# 波束预测函数：该函数的目标是预测 给定 时刻或给定数据条件下最合适的波束索引
def beam_prediction(csv_frame, data_list_y):
    data_label = np.empty((0, 1), dtype=int)
    for i in tqdm(np.arange(len(data_list_y))):
        power_levels = np.loadtxt(csv_frame[csv_frame.index == data_list_y[i][0]]['unit1_pwr_60ghz'].item())
        beam_idx = np.argmax(power_levels) + 1
        data_label = np.append(data_label, beam_idx)
    return data_label

# 波束追踪函数：该函数的目标是追踪 多个 时刻或多个位置下的波束选择，并且追踪的是一个时间序列或多个索引对应的波束
def beam_tracking(csv_frame, data_list_y):
    data_label = np.empty((len(data_list_y), len(data_list_y[0])), dtype=int)
    base_path = '/content/scenario8/unit1/mmWave_data/'
    for i in tqdm(np.arange(len(data_list_y))):
        beam_idx = np.zeros(len(data_list_y[0]))
        for j in range(len(data_list_y[i])):
            file_name = csv_frame[csv_frame.index == data_list_y[i][j]]['unit1_pwr_60ghz'].item()
            power_levels =  os.path.join(base_path, file_name)
            beam_idx[j] = np.argmax(power_levels) + 1
        data_label[i] = beam_idx
    return data_label


class TimeSeriesGenerator:
    def __init__(self,
                 csv_file='scenario8.csv',
                 x_size=5,
                 y_size=1,
                 delay=0,
                 seed=5,
                 label_function=beam_prediction,
                 save_filename=None):

        # x_size: Size of input samples
        # y_size: Size of label samples to generate the labels
        # delay: The number of samples between x and y sequences.

        # Example Series: x_size=3, y_size=1, delay=0 --> [1 2 3] [4]
        # Example Series: x_size=3, y_size=1, delay=1 --> [1 2 3] [5]
        # Example Series: x_size=3, y_size=2, delay=0 --> [1 2 3] [4, 5]

        # For Beam Prediction, y_size=1
        # For Blockage Prediction, y_size=# of samples in a blockage duration (e.g., 3 samples)

        if save_filename == None:
            self.save_filename = csv_file.split('.')[0] + '_series' + '.csv'
        else:
            self.save_filename = save_filename

        self.csv_frame = pd.read_csv(csv_file, index_col='index')
        self.num_sequences = self.csv_frame['seq_index'].max()

        self.seq_start = []
        self.seq_end = []
        self._extract_seq_start_end()

        self.x_size = x_size
        self.y_size = y_size
        self.delay = delay

        self.data_list_x = np.empty((0, x_size), dtype=int)
        self.data_list_y = np.empty((0, y_size), dtype=int)
        self.data_list_seq = np.empty((0, y_size), dtype=int)

        self._generate_indices()

        self.data_labels = label_function(self.csv_frame, self.data_list_y)

        # Shuffling indices
        self.num_datapoints = len(self.data_list_y)
        self.data_idx = np.arange(self.num_datapoints)
        rng = np.random.default_rng(seed)
        rng.shuffle(self.data_idx)

        # Shuffling sequences
        self.seq_idx = np.arange(self.num_sequences)
        rng = np.random.default_rng(seed)
        rng.shuffle(self.seq_idx)

    def _extract_seq_start_end(self):
        for i in np.arange(self.num_sequences) + 1:
            data_indices = self.csv_frame[self.csv_frame['seq_index'] == i].index
            self.seq_start.append(data_indices.min())
            self.seq_end.append(data_indices.max())

    def _generate_indices(self):
        for i in range(len(self.seq_start)):
            x_start_ind = self.seq_start[i]
            x_end_ind = x_start_ind + self.x_size

            y_start_ind = x_end_ind + self.delay
            y_end_ind = y_start_ind + self.y_size

            while y_end_ind <= self.seq_end[i] + 1:
                self.data_list_x = np.vstack((self.data_list_x, np.arange(x_start_ind, x_end_ind)))
                self.data_list_y = np.vstack((self.data_list_y, np.arange(y_start_ind, y_end_ind)))
                self.data_list_seq = np.append(self.data_list_seq, i)
                x_start_ind += 1
                x_end_ind += 1
                y_start_ind += 1
                y_end_ind += 1

    def take(self, num_of_data):
        new_dataset = copy.copy(self)
        new_dataset.data_idx = new_dataset.data_idx[:num_of_data]
        new_dataset.num_datapoints = len(new_dataset.data_idx)
        return new_dataset

    def skip(self, num_of_data):
        new_dataset = copy.copy(self)
        new_dataset.data_idx = new_dataset.data_idx[num_of_data:]
        new_dataset.num_datapoints = len(new_dataset.data_idx)
        return new_dataset

    def take_by_idx(self, idx):
        new_dataset = copy.copy(self)
        new_dataset.data_idx = new_dataset.data_idx[idx]
        new_dataset.num_datapoints = len(new_dataset.data_idx)
        return new_dataset

    def __len__(self):
        return self.num_datapoints

    def save_split_files(self, split=(0.7, 0.2, 0.1), data_path_csv_column=None, label_path_csv_column=None,
                         split_names=('train', 'val', 'test'), label_name='beam_index', sequence_split=False,
                         save_y_ind=False):
        if sequence_split:
            num_sequences = self.num_sequences
            num_train = int(num_sequences * split[0])
            num_val = int(num_sequences * split[1])
            idx_train = np.where(np.in1d(self.data_list_seq[self.data_idx], self.seq_idx[:num_train]))
            idx_val = np.where(np.in1d(self.data_list_seq[self.data_idx], self.seq_idx[num_train:num_train + num_val]))
            idx_test = np.where(np.in1d(self.data_list_seq[self.data_idx], self.seq_idx[num_train + num_val:]))


        else:
            # Train, Validation, Test
            num_datapoints = len(self)
            num_train = int(num_datapoints * split[0])
            num_val = int(num_datapoints * split[1])
            idx_train = np.arange(0, num_train)
            idx_val = np.arange(num_train, num_train + num_val)
            idx_test = np.arange(num_train + num_val, num_datapoints)

        idx_list = [idx_train, idx_val, idx_test]
        for n, name in enumerate(split_names):
            self.take_by_idx(idx_list[n]).save_file(file_tag=name, data_path_csv_column=data_path_csv_column,
                                                    label_path_csv_column=label_path_csv_column, label_name=label_name,
                                                    shuffled=True, save_y_ind=save_y_ind)
        '''
        self.take_by_idx(idx_train).save_file(file_tag=split_names[0], data_path_csv_column=data_path_csv_column, label_path_csv_column=label_path_csv_column, label_name=label_name, shuffled=True, save_y_ind=False)
        self.take_by_idx(idx_val).save_file(file_tag=split_names[1], data_path_csv_column=data_path_csv_column, label_path_csv_column=label_path_csv_column, label_name=label_name, shuffled=True, save_y_ind=False)
        self.take_by_idx(idx_test).save_file(file_tag=split_names[2], data_path_csv_column=data_path_csv_column, label_path_csv_column=label_path_csv_column, label_name=label_name, shuffled=True,save_y_ind=False)
        '''

    # data_path_csv_column: If the location of the sequences are required
    # in the output csv file, input the name of the csv column
    # (e.g., 'unit1_radar').
    def save_file(self, file_tag='', data_path_csv_column=None, label_path_csv_column=None, label_name='label',
                  shuffled=False, save_y_ind=False):
        if data_path_csv_column is None:
            df_x = pd.DataFrame(self.data_list_x, columns=['x_%i' % (i + 1) for i in range(self.x_size)])
        else:
            df_x = pd.DataFrame(self.csv_frame[data_path_csv_column].to_numpy(str)[self.data_list_x - 1],
                                columns=['x_%i' % (i + 1) for i in range(self.x_size)])
        if save_y_ind:
            if label_path_csv_column is None:
                df_y = pd.DataFrame(self.data_list_y, columns=['y_%i' % (i + 1) for i in range(self.y_size)])
            else:
                df_y = pd.DataFrame(self.csv_frame[label_path_csv_column].to_numpy(str)[self.data_list_y - 1],
                                    columns=['y_%i' % (i + 1) for i in range(self.y_size)])
        else:
            df_y = pd.DataFrame()
        df_label = pd.DataFrame(self.data_labels,
                                columns=['%s_%i' % (label_name, i + 1) for i in range(self.data_labels.shape[1])])
        df = pd.concat([df_x, df_y, df_label], axis=1)
        df.index.name = 'index'
        df.index += 1

        if shuffled:
            df = df.iloc[self.data_idx]
            df.index.name = 'data_index'
            df = df.reset_index()
            df.index.name = 'index'
            df.index += 1

        filename = self.save_filename.split('.')[0] + '_' + file_tag + '.csv'
        df.to_csv(filename)
        print('%i data points are saved to %s' % (len(df), filename))


# %% Generate Time Series & Save Series CSV Files

csv_file = '/content/DEV[95%]/scenario8.csv'

# May define your own label extraction function if needed
# blockage_prediction and beam_prediction are currently available
label_function = beam_tracking

# Data sequence size
x_size = 8
# Label sequence size, beam_prediction --> 1, blockage_prediction --> 3
y_size = x_size + 5
# Delay
delay = -x_size

rng_seed = 5  # Reproducibility

# The file name will be included in the series
data_path_csv_column = 'unit1_rgb'
label_path_csv_column = 'unit1_pwr_60ghz'
save_y_ind = True
# Name of the labels
label_name = 'beam_index'

# Sequence or data split of the files
# If False, the data is fully shuffled
# If True, Train-Validation-Test sets are separated by the sequence
# i.e., any of the sets will not have a shared sequence
sequence_split = True

x = TimeSeriesGenerator(csv_file=csv_file,
                        x_size=x_size,
                        y_size=y_size,
                        seed=rng_seed,
                        delay=delay,
                        label_function=label_function,
                        save_filename='/content/DEV[95%]/scenario8_series.csv')

x.save_file(file_tag='full', data_path_csv_column=data_path_csv_column, label_path_csv_column=label_path_csv_column,
            label_name=label_name, shuffled=False, save_y_ind=save_y_ind)
x.save_split_files(split=(0.8, 0.0, 0.2), data_path_csv_column=data_path_csv_column,
                   label_path_csv_column=label_path_csv_column, label_name=label_name, sequence_split=sequence_split,
                   save_y_ind=save_y_ind)

100%|██████████| 3048/3048 [00:11<00:00, 262.79it/s]


3048 data points are saved to /content/DEV[95%]/scenario8_series_full.csv
2435 data points are saved to /content/DEV[95%]/scenario8_series_train.csv
0 data points are saved to /content/DEV[95%]/scenario8_series_val.csv
613 data points are saved to /content/DEV[95%]/scenario8_series_test.csv


In [4]:
## Pre-Processing

def preprocess(in_path, out_path):
    base_path = '/content/DEV[95%]'
    csv_file_path = in_path
    csv_save_path = out_path
    df = pd.read_csv(csv_file_path)

    cols1 = ['x_'+str(s) for s in range(1,9)]
    cols2 = ['y_'+str(s) for s in range(1,14)]
    for index, row in tqdm(df.iterrows(), total=df.shape[0]):
        for c in cols1:
            relative_path = row.loc[c]
            relative_path = relative_path.replace('camera_data', 'camera_data_bbox').replace('jpg', 'txt')
            full_path = os.path.join(base_path, relative_path)
            path = os.path.normpath(os.path.join(base_path, relative_path.lstrip('./')))
            try:
                content = np.loadtxt(path)[1:]
            except:
                content = np.zeros(4)
            if not content.size:
                content = np.zeros(4)
            df.at[index, c] = np.array2string(content, separator=',')
        for c in cols2:
            relative_path = row.loc[c]
            path = os.path.normpath(os.path.join(base_path, relative_path.lstrip('./')))
            content = np.loadtxt(path)
            df.at[index, c] = np.array2string(content, separator=',')

    df.to_csv(csv_save_path, index=False)


#%% training dataset
csv_file_path = '/content/DEV[95%]/scenario8_series_train.csv'
csv_save_path = '/content/DEV[95%]/scenario8_series_bbox_train.csv'
preprocess(csv_file_path, csv_save_path)

#%% test dataset
csv_file_path = '/content/DEV[95%]/scenario8_series_test.csv'
csv_save_path = '/content/DEV[95%]/scenario8_series_bbox_test.csv'
preprocess(csv_file_path, csv_save_path)

100%|██████████| 2435/2435 [00:36<00:00, 66.33it/s]
  content = np.loadtxt(path)[1:]
  content = np.loadtxt(path)[1:]
100%|██████████| 613/613 [00:09<00:00, 63.99it/s]


In [5]:
## Data Feeder

def create_samples(root, portion=1.0, num_beam=64):
    f = pd.read_csv(root)
    bbox_all = []
    beam_power_all = []
    for idx, row in f.iterrows():
        bboxes = np.stack(
            [np.asarray(ast.literal_eval(r)) for r in row.loc["x_1":"x_8"]], axis=0
        )
        bbox_all.append(bboxes)
        beam_powers = np.stack(
            [np.asarray(ast.literal_eval(r)) for r in row.loc["y_1":"y_13"]], axis=0
        )
        beam_power_all.append(beam_powers)

    bbox_all = np.stack(bbox_all, axis=0)
    beam_power_all = np.stack(beam_power_all, axis=0)
    best_beam = np.argmax(beam_power_all, axis=-1)

    print("list is ready")
    num_data = len(beam_power_all)
    num_data = int(num_data * portion)
    return bbox_all[:num_data], best_beam[:num_data, -5:]


class DataFeed(Dataset):
    def __init__(self, root_dir, portion=1.0, num_beam=64):

        self.root = root_dir
        self.samples, self.pred_val = create_samples(
            self.root, portion=portion, num_beam=num_beam
        )
        self.seq_len = 8
        self.num_beam = num_beam

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        samples = self.samples[idx]  # Read one data sample
        pred_val = self.pred_val[idx]

        samples = samples[-self.seq_len :]  # Read a sequence of tuples from a sample

        out_beam = torch.zeros((5,))
        bbox = torch.zeros((self.seq_len, 4))

        if not samples.size:
            samples = np.zeros(4)
        bbox = torch.tensor(samples, requires_grad=False)

        out_beam = torch.tensor(pred_val, requires_grad=False)

        return bbox.float(), out_beam.long()

In [6]:
## Model

class GruModelSimple(nn.Module):
    def __init__(self, num_classes, num_layers=3, hidden_size=64, embed_size=64):
        super(GruModelSimple, self).__init__()
        self.embed = torch.nn.Linear(4, embed_size)
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.gru = torch.nn.GRU(
            input_size=embed_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=0.3,
        )
        self.fc = torch.nn.Linear(hidden_size, num_classes)
        self.name = "GruModelSimple"
        self.dropout1 = nn.Dropout(0.5)

    def initHidden(self, batch_size):
        return torch.zeros((self.num_layers, batch_size, self.hidden_size))

    def forward(self, x, h):
        y = self.embed(x)
        y = self.dropout1(y)
        y, h = self.gru(y, h)
        y = self.fc(y)
        return y, h


In [8]:
## Training

def train_model(num_epoch=100, if_writer=False, portion=1.0, num_beam=64):
    num_classes = num_beam + 1
    batch_size = 8
    val_batch_size = 64
    train_dir = "/content/DEV[95%]/scenario8_series_bbox_train.csv"
    val_dir = "/content/DEV[95%]/scenario8_series_bbox_test.csv"
    train_loader = DataLoader(
        DataFeed(train_dir, num_beam=num_beam), batch_size=batch_size, shuffle=True
    )
    val_loader = DataLoader(
        DataFeed(val_dir, num_beam=num_beam), batch_size=val_batch_size, shuffle=False
    )

    # check gpu acceleration availability
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    now = datetime.datetime.now().strftime("%H_%M_%S")
    date = datetime.date.today().strftime("%y_%m_%d")

    # Instantiate the model
    net = GruModelSimple(num_classes)
    # path to save the model
    checkpoint_dir = "./checkpoint/"
    os.makedirs(checkpoint_dir, exist_ok=True)

    PATH = os.path.join(checkpoint_dir, f"{now}_{date}_{net.name}.pth")
    #PATH = "./checkpoint/" + now + "_" + date + "_" + net.name + "" + ".pth"
    # print model summary
    if if_writer:
        h = net.initHidden(1)
        print(summary(net, torch.zeros((8, 1, 4)), h))
    # send model to GPU
    net.to(device)

    # set up loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[10, 30, 50], gamma=0.1)
    if if_writer:
        writer = SummaryWriter(comment=now + "_" + date + "_" + net.name)
        # writer.add_graph(net, (torch.zeros((1, 1, 1024)), torch.zeros((1, 1, 1024))))

    # train model
    val_top1_acc_his = []
    val_top2_acc_his = []
    val_top3_acc_his = []
    val_top5_acc_his = []

    for epoch in range(num_epoch):  # loop over the dataset multiple times
        net.train()
        running_loss = 0.0
        running_acc = 1.0
        with tqdm(train_loader, unit="batch", file=sys.stdout) as tepoch:
            for i, (bbox, label) in enumerate(tepoch, 0):
                tepoch.set_description(f"Epoch {epoch}")
                # get the input np arrays, bbox sequence (batch_size, 8, 4) label (batch_size, 1)
                #bbox = (bbox - bbox.mean(dim=1, keepdim=True)) / bbox.std(dim=1, keepdim=True)
                bbox = torch.swapaxes(bbox, 0, 1)
                bbox = torch.cat(
                    [bbox, torch.zeros(torch.Size((4,)) + bbox.shape[1:]) - 1], dim=0
                )
                label = torch.swapaxes(label, 0, 1)
                bbox = bbox.to(device)
                label = label.to(device)
                optimizer.zero_grad()

                h = net.initHidden(bbox.shape[1]).to(device)
                outputs, _ = net(bbox, h)
                outputs = outputs[-5:, ...]
                loss = criterion(outputs.view(-1, num_classes), label.flatten())
                prediction = torch.argmax(outputs, dim=-1)
                acc = (prediction == label).sum().item() / int(
                    torch.sum(label != -100).cpu()
                )
                loss.backward()
                optimizer.step()
                # print statistics
                running_loss = (loss.item() + i * running_loss) / (i + 1)
                running_acc = (acc + i * running_acc) / (i + 1)
                log = OrderedDict()
                log["loss"] = running_loss
                log["acc"] = running_acc
                tepoch.set_postfix(log)
            scheduler.step()
            # validation
            predictions = []
            net.eval()
            with torch.no_grad():
                total = np.zeros((5,))
                top1_correct = np.zeros((5,))
                top2_correct = np.zeros((5,))
                top3_correct = np.zeros((5,))
                top5_correct = np.zeros((5,))

                val_loss = 0
                for (bbox, label) in val_loader:
                    bbox = torch.swapaxes(bbox, 0, 1)
                    #bbox = (bbox - bbox.mean(dim=1, keepdim=True)) / bbox.std(dim=1, keepdim=True)
                    bbox = torch.cat(
                        [bbox, torch.zeros(torch.Size((4,)) + bbox.shape[1:]) - 1],
                        dim=0,
                    )
                    label = torch.swapaxes(label, 0, 1)
                    bbox = bbox.to(device)
                    label = label.to(device)
                    optimizer.zero_grad()

                    h = net.initHidden(bbox.shape[1]).to(device)
                    outputs, _ = net(bbox, h)
                    outputs = outputs[-5:, ...]
                    label = label[-5:, ...]
                    val_loss += nn.CrossEntropyLoss(reduction="sum")(
                        outputs.view(-1, num_classes), label.flatten()
                    ).item()
                    total += torch.sum(label != -100, dim=-1).cpu().numpy()
                    prediction = torch.argmax(outputs, dim=-1)
                    top1_correct += torch.sum(prediction == label, dim=-1).cpu().numpy()
                    _, idx = torch.topk(outputs, 5, dim=-1)
                    idx = idx.cpu().numpy()
                    idx = np.minimum(idx, num_beam - 1)
                    label = label.cpu().numpy()
                    for i in range(label.shape[0]):
                        for j in range(label.shape[1]):
                            top2_correct[i] += np.isin(label[i, j], idx[i, j, :2]).sum()
                            top3_correct[i] += np.isin(label[i, j], idx[i, j, :3]).sum()
                            top5_correct[i] += np.isin(label[i, j], idx[i, j, :5]).sum()

                    predictions.append(prediction.cpu().numpy())

                val_loss /= total.sum()
                #scheduler.step(val_loss)
                val_top1_acc = top1_correct / total
                val_top2_acc = top2_correct / total
                val_top3_acc = top3_correct / total
                val_top5_acc = top5_correct / total

                print("val_loss={:.4f}".format(val_loss), flush=True)
                print("accuracy", flush=True)
                print(
                    np.stack(
                        [val_top1_acc, val_top2_acc, val_top3_acc, val_top5_acc], 0
                    ),
                    flush=True,
                )
                print("power", flush=True)
        if if_writer:
            writer.add_scalar("Loss/train", running_loss, epoch)
            writer.add_scalar("Loss/test", val_loss, epoch)
            writer.add_scalar("acc/train", running_acc, epoch)
            writer.add_scalar("acc/test", val_top1_acc[0], epoch)
        val_top1_acc_his.append(val_top1_acc)
        val_top2_acc_his.append(val_top2_acc)
        val_top3_acc_his.append(val_top3_acc)
        val_top5_acc_his.append(val_top5_acc)

    if if_writer:
        writer.close()
        torch.save(net.state_dict(), PATH)

    his = {
        "acc_top1": val_top1_acc_his,
        "acc_top2": val_top2_acc_his,
        "acc_top3": val_top3_acc_his,
        "acc_top5": val_top5_acc_his,
    }
    print("Finished Training")

    # load the model
    net.to(device)
    net.eval()
    # test
    predictions = []
    raw_predictions = []
    net.eval()
    with torch.no_grad():
        total = np.zeros((5,))
        top1_correct = np.zeros((5,))
        top2_correct = np.zeros((5,))
        top3_correct = np.zeros((5,))
        top5_correct = np.zeros((5,))
        val_loss = 0
        for (bbox, label) in val_loader:
            bbox = torch.swapaxes(bbox, 0, 1)
            bbox = torch.cat(
                [bbox, torch.zeros(torch.Size((4,)) + bbox.shape[1:]) - 1], dim=0
            )
            label = torch.swapaxes(label, 0, 1)
            bbox = bbox.to(device)
            label = label.to(device)

            bbox = bbox.to(device)
            label = label.to(device)

            h = net.initHidden(bbox.shape[1]).to(device)
            outputs, _ = net(bbox, h)
            outputs = outputs[-5:, ...]
            label = label[-5:, ...]
            val_loss += nn.CrossEntropyLoss(reduction="sum")(
                outputs.view(-1, num_classes), label.flatten()
            ).item()
            total += torch.sum(label != -100, dim=-1).cpu().numpy()
            prediction = torch.argmax(outputs, dim=-1)
            top1_correct += torch.sum(prediction == label, dim=-1).cpu().numpy()

            _, idx = torch.topk(outputs, 5, dim=-1)
            idx = idx.cpu().numpy()
            label = label.cpu().numpy()
            for i in range(label.shape[0]):
                for j in range(label.shape[1]):
                    top2_correct[i] += np.isin(label[i, j], idx[i, j, :2]).sum()
                    top3_correct[i] += np.isin(label[i, j], idx[i, j, :3]).sum()
                    top5_correct[i] += np.isin(label[i, j], idx[i, j, :5]).sum()

            predictions.append(prediction.cpu().numpy())
            raw_predictions.append(outputs.cpu().numpy())

        val_loss /= total.sum()
        val_top1_acc = top1_correct / total
        val_top2_acc = top2_correct / total
        val_top3_acc = top3_correct / total
        val_top5_acc = top5_correct / total

        predictions = np.concatenate(predictions, 1)
        raw_predictions = np.concatenate(raw_predictions, 1)

        val_acc = {
            "top1": val_top1_acc,
            "top2": val_top2_acc,
            "top3": val_top3_acc,
            "top5": val_top5_acc,
        }
        return val_loss, val_acc, predictions, raw_predictions, his


if __name__ == "__main__":
    torch.manual_seed(42)
    num_epoch = 100
    val_loss, val_acc, predictions, raw_predictions, his = train_model(
        num_epoch=num_epoch, if_writer=True, portion=1.0, num_beam=64
    )
    print(val_acc)

    plot_dir = "./plot"
    os.makedirs(plot_dir, exist_ok=True)

    file_path = os.path.join(plot_dir, "test_acc.mat")
    savemat(file_path, {'test_acc': val_acc})
    print(f"Results saved at: {file_path}")

    savemat('plot/test_acc.mat',{'test_acc':val_acc})

list is ready
list is ready
cuda
------------------------------------------------------------------------------
      Layer (type)               Output Shape         Param #     Tr. Param #
          Linear-1                 [8, 1, 64]             320             320
         Dropout-2                 [8, 1, 64]               0               0
             GRU-3     [8, 1, 64], [3, 1, 64]          74,880          74,880
          Linear-4                 [8, 1, 65]           4,225           4,225
Total params: 79,425
Trainable params: 79,425
Non-trainable params: 0
------------------------------------------------------------------------------
Epoch 0: 100%|██████████| 305/305 [00:04<00:00, 70.95batch/s, loss=2.88, acc=0.226]
val_loss=2.2965
accuracy
[[0.28548124 0.30342577 0.31810767 0.31484502 0.3050571 ]
 [0.42577488 0.45513866 0.48287113 0.50897227 0.51223491]
 [0.51712887 0.54649266 0.57422512 0.59216966 0.60195759]
 [0.69331158 0.71288744 0.72593801 0.73409462 0.7324633 ]]
power
E