### 1. 加载模块

In [1]:
%load_ext autoreload
%autoreload 2
from utils import load_all
from Loss_Function import CLOULoss, DistributionFocalLoss
from Model import Model
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
import config

### 2.测试csv

In [2]:
# image, csv = load_all(config.DIRPATH)

In [3]:
# csv_slice = csv[3].iloc[:, 3:].values
# csv

### 3. 定义Dataset

In [4]:
data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((640, 640)),
    transforms.Normalize((0.5,), (0.5,))
])
class_to_index = {
    'E2': 0, 'B52': 1, 'B2': 2, 'Mirage2000': 3, 'F4': 4,
    'F14': 5, 'Tornado': 6, 'J20': 7, 'JAS39': 8
}

class MyDataset(Dataset):
    '''
    将图片和[class, Bbox]转化为Tensor
    返回值：
    1. image: [bs, 640, 640, 3]
    2. target: [bs, 物品个数, class, xmin, ymin, xmax, ymax]
    '''
    def __init__(self, path, transform):
        super().__init__()
        self.image, self.csv = load_all(path)
        self.transform = transform
    def __len__(self):
        return len(self.image)

    def __getitem__(self, index):
        image = self.transform(self.image[index])
        csv = self.csv[index]
        # filename width height class xmin ymin xmax ymax
        target_array = csv.iloc[:, 3:].values
        target = torch.zeros((20, 5))
        for i in range(len(target_array)):
            target[i, 0] = class_to_index[target_array[i][0]]
            target[i, 1] = target_array[i][1]
            target[i, 2] = target_array[i][2]
            target[i, 3] = target_array[i][3]
            target[i, 4] = target_array[i][4]

        return image, target

### 4. 定义Dataloader, net, Loss_Function, optimizer

In [5]:
dataset = MyDataset(config.DIRPATH, data_transform)

In [6]:
device = torch.device('cuda')
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)
F1 = nn.BCELoss()
F2 = CLOULoss()
F3 = DistributionFocalLoss()
reg_max, num_class = 20, 9
net = Model(reg_max, num_class).to(device)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.937)

### 5. 训练过程

In [7]:
torch.cuda.empty_cache()

In [8]:
for i, (image, target) in enumerate(dataloader):
    image = image.to(device)
    target = target.to(device)
    BLS1, CLS1, BLS2, CLS2, BLS3, CLS3 = net(image)
    print(BLS1.shape, BLS2.shape, BLS3.shape, CLS1.shape, CLS2.shape, CLS3.shape)
    BLS = torch.cat((BLS1, BLS2, BLS3), dim=1) # (b, 8400, 80)
    CLS = torch.cat((CLS1, CLS2, CLS3), dim=1) # (b, 8400, 4)
    print(BLS.shape, CLS.shape)
    break

torch.Size([4, 80, 80, 80]) torch.Size([4, 80, 40, 40]) torch.Size([4, 80, 20, 20]) torch.Size([4, 9, 80, 80]) torch.Size([4, 9, 40, 40]) torch.Size([4, 9, 20, 20])


RuntimeError: Sizes of tensors must match except in dimension 2. Got 20 and 80 (The offending index is 0)