In [3]:
import os
import argparse

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

from my_dataset import MyDataSet_1
from data_set import CustomCWRUDataset,CWRU
from model import swin_tiny_patch4_window7_224 as create_model
from utils import read_split_data, train_one_epoch, evaluate


def main(args):
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    if os.path.exists("./weights") is False:
        os.makedirs("./weights")

    tb_writer = SummaryWriter()

    data_1 = CWRU(root_dir=args.data_path)

    train_dataset, val_dataset = data_1.train_test_split_order()
    
    batch_size = args.batch_size




    
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               )

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             )

    model = create_model(num_classes=args.num_classes).to(device)

    if args.weights != "":
        assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights)
        weights_dict = torch.load(args.weights, map_location=device)["model"]
        # 删除有关分类类别的权重
        for k in list(weights_dict.keys()):
            if "head" in k:
                del weights_dict[k]
        print(model.load_state_dict(weights_dict, strict=False))

    if args.freeze_layers:
        for name, para in model.named_parameters():
            # 除head外，其他权重全部冻结
            if "head" not in name:
                para.requires_grad_(False)
            else:
                print("training {}".format(name))

    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.AdamW(pg, lr=args.lr, weight_decay=5E-2)

    for epoch in range(args.epochs):
        # train
        train_loss, train_acc = train_one_epoch(model=model,
                                                optimizer=optimizer,
                                                data_loader=train_loader,
                                                device=device,
                                                epoch=epoch)

        # validate
        val_loss, val_acc = evaluate(model=model,
                                     data_loader=val_loader,
                                     device=device,
                                     epoch=epoch)

        tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
        tb_writer.add_scalar(tags[0], train_loss, epoch)
        tb_writer.add_scalar(tags[1], train_acc, epoch)
        tb_writer.add_scalar(tags[2], val_loss, epoch)
        tb_writer.add_scalar(tags[3], val_acc, epoch)
        tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)

        torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))


# if __name__ == '__main__':
    # parser = argparse.ArgumentParser()
    # parser.add_argument('--num_classes', type=int, default=9)
    # parser.add_argument('--epochs', type=int, default=10)
    # parser.add_argument('--batch-size', type=int, default=8)
    # parser.add_argument('--lr', type=float, default=0.0008)

    # # 数据集所在根目录
    # # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
    # parser.add_argument('--data-path', type=str,
    #                     default="D:/BaiduNetdiskDownload/UAV_Fault_Dataset/Train")

    # # 预训练权重路径，如果不想载入就设置为空字符
    # parser.add_argument('--weights', type=str, default='',
    #                     help='initial weights path')
    # # 是否冻结权重
    # parser.add_argument('--freeze-layers', type=bool, default=False)
    # parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')

    # opt = parser.parse_args()

    # main(opt)

In [4]:
path = "D:/BaiduNetdiskDownload/UAV_Fault_Dataset/Train"
data_1 = CWRU(root_dir=path)

train_dataset, val_dataset = data_1.train_test_split_order()

In [8]:
train_dataset.data_pd

Unnamed: 0,data,label
13629,D:/BaiduNetdiskDownload/UAV_Fault_Dataset/Trai...,2
61094,D:/BaiduNetdiskDownload/UAV_Fault_Dataset/Trai...,10
49414,D:/BaiduNetdiskDownload/UAV_Fault_Dataset/Trai...,8
53181,D:/BaiduNetdiskDownload/UAV_Fault_Dataset/Trai...,10
51001,D:/BaiduNetdiskDownload/UAV_Fault_Dataset/Trai...,8
...,...,...
8017,D:/BaiduNetdiskDownload/UAV_Fault_Dataset/Trai...,0
78923,D:/BaiduNetdiskDownload/UAV_Fault_Dataset/Trai...,14
92065,D:/BaiduNetdiskDownload/UAV_Fault_Dataset/Trai...,16
23442,D:/BaiduNetdiskDownload/UAV_Fault_Dataset/Trai...,4


In [10]:
train_loader = torch.utils.data.DataLoader(train_dataset,
                                            batch_size=50,
                                            shuffle=True,
                                            pin_memory=True,
                                            num_workers=2,
                                            )

In [13]:
train_loader

TypeError: 'BatchSampler' object is not callable

In [11]:
for epoch in range(1):
        for step, (batch_x, batch_y) in enumerate(train_loader):
            # training
            print("steop:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y))


steop:0, batch_x:tensor([[ 6.7986e-02,  1.5764e-01,  1.2074e-01,  ...,  1.5436e-01,
          3.8256e-02,  3.3566e-02],
        [-1.6771e-01, -1.1446e-01, -1.6113e-01,  ..., -1.3407e-02,
         -8.7245e-02, -6.0785e-04],
        [ 1.4448e-01,  1.8626e-01,  2.0455e-01,  ...,  7.8439e-03,
          2.4354e-04,  8.1494e-02],
        ...,
        [ 2.5075e-01,  2.1021e-01,  2.9214e-01,  ..., -2.8464e-01,
         -2.2513e-01, -3.4164e-01],
        [-4.8891e-01, -5.1333e-01, -6.0393e-01,  ..., -2.9323e-01,
         -2.1014e-01, -2.0760e-01],
        [ 2.7761e-02,  7.0255e-02,  5.5642e-02,  ...,  1.4704e-01,
          1.2047e-01,  1.1212e-01]]), batch_y:tensor([ 0,  0,  2, 12,  4, 16,  8, 10,  4,  4, 14, 16,  6, 12,  2,  8,  2,  4,
         6,  6, 12, 16, 16,  2,  0,  6, 10,  6, 16,  0,  2, 14, 14, 10, 16, 14,
        10,  0, 16, 16,  0, 16, 16, 10,  4, 16,  4,  2,  8, 16])
steop:1, batch_x:tensor([[ 0.3636,  0.2075,  0.2173,  ..., -0.2752, -0.2870, -0.2920],
        [-0.5920, -0.4182, -0.

MatReadError: Caught MatReadError in DataLoader worker process 1.
Original Traceback (most recent call last):
  File "d:\Program\Anaconda\envs\DL\lib\site-packages\hdf5storage\__init__.py", line 1777, in loadmat
    with h5py.File(filename, mode='r') as f:
  File "d:\Program\Anaconda\envs\DL\lib\site-packages\h5py\_hl\files.py", line 562, in __init__
    fid = make_fid(name, mode, userblock_size, fapl, fcpl, swmr=swmr)
  File "d:\Program\Anaconda\envs\DL\lib\site-packages\h5py\_hl\files.py", line 235, in make_fid
    fid = h5f.open(name, flags, fapl=fapl)
  File "h5py\_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
  File "h5py\_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
  File "h5py\h5f.pyx", line 102, in h5py.h5f.open
FileNotFoundError: [Errno 2] Unable to synchronously open file (unable to open file: name = 'D:/BaiduNetdiskDownload/UAV_Fault_Dataset/Train\MF1\MF1_330.mat.baiduyun.downloading.cfg.mat', errno = 2, error message = 'No such file or directory', flags = 0, o_flags = 0)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "d:\Program\Anaconda\envs\DL\lib\site-packages\torch\utils\data\_utils\worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "d:\Program\Anaconda\envs\DL\lib\site-packages\torch\utils\data\_utils\fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "d:\Program\Anaconda\envs\DL\lib\site-packages\torch\utils\data\_utils\fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "d:\Learn\deep-learning-for-image-processing-master\pytorch_classification\swin_transformer\data_set.py", line 17, in __getitem__
    data = hdf5storage.loadmat(file_path)['single_data']#提取数据并转置
  File "d:\Program\Anaconda\envs\DL\lib\site-packages\hdf5storage\__init__.py", line 1811, in loadmat
    return scipy.io.loadmat(file_name, mdict, appendmat=appendmat,
  File "d:\Program\Anaconda\envs\DL\lib\site-packages\scipy\io\matlab\_mio.py", line 226, in loadmat
    MR, _ = mat_reader_factory(f, **kwargs)
  File "d:\Program\Anaconda\envs\DL\lib\site-packages\scipy\io\matlab\_mio.py", line 74, in mat_reader_factory
    mjv, mnv = _get_matfile_version(byte_stream)
  File "d:\Program\Anaconda\envs\DL\lib\site-packages\scipy\io\matlab\_miobase.py", line 232, in _get_matfile_version
    raise MatReadError("Mat file appears to be empty")
scipy.io.matlab._miobase.MatReadError: Mat file appears to be empty
