In [None]:
import torch
from torch import nn, optim
from torchvision import transforms, utils
from torch.utils.data import DataLoader, random_split
from dataset import PosterDataset, Resize, ToTensor
import numpy as np

# Hyperparameters: batch size, epoch count and learning rate.
# In the code visible in this cell only `bs` is used (by both DataLoaders);
# `epochs` and `lr` are not referenced here.
bs, epochs, lr = 32, 50, 1e-3

def props_to_onehot(props):
    """Turn each row of class scores into a one-hot row marking its maximum.

    Args:
        props: 2-D numpy array of scores (one row per sample, one column per
            class), or a nested Python list, which is first converted with
            ``np.array``.

    Returns:
        A float numpy array with the same shape as ``props``. Each row holds a
        single 1.0 at that row's ``np.argmax`` position (the first maximum on
        ties) and 0.0 everywhere else.
    """
    scores = np.array(props) if isinstance(props, list) else props
    best = np.argmax(scores, axis=1)
    n_rows = len(best)
    onehot = np.zeros((n_rows, scores.shape[1]))
    rows = np.arange(n_rows)
    onehot[rows, best] = 1
    return onehot

# --- Data: load posters and split them into train / test subsets ----------
transformed_dataset = PosterDataset(csv_file='./data.txt',
                                    root_dir='../posters',
                                    transform=transforms.Compose([
                                        Resize(),
                                        ToTensor()
                                    ]))
# random_split requires the lengths to sum to len(dataset) exactly. The old
# `int(0.8*n + 1)` + `int(0.2*n)` overshoots by one whenever n is a multiple
# of 5 (e.g. n=10 -> 9 + 2 = 11), so the test size is taken as the remainder.
n_samples = len(transformed_dataset)
train_size = int(0.8 * n_samples)
test_size = n_samples - train_size
# NOTE(review): the split is unseeded, so it is not guaranteed to match the
# split used when 'net.pkl' was trained -- "test" samples may have been seen
# during training, inflating the accuracy below. Confirm against the
# training code and share a seeded generator if needed.
train_dataset, test_dataset = random_split(transformed_dataset, [train_size, test_size])
data_loader1 = DataLoader(train_dataset, batch_size=bs, shuffle=True)
# Evaluation order does not matter (accuracy is a sum over all samples).
data_loader2 = DataLoader(test_dataset, batch_size=bs, shuffle=False)
print('train batches: ', len(data_loader1))
print('test batches: ', len(data_loader2))

# --- Model: use the GPU when available, otherwise fall back to CPU --------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# torch.load unpickles an entire module object: only load trusted files.
# map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
# NOTE(review): recent PyTorch releases default to weights_only=True, which
# rejects whole-module pickles; pass weights_only=False there if needed.
model = torch.load('net.pkl', map_location=device).to(device)
train_acc = []  # NOTE: despite the name, this collects accuracy on the test split

for epoch in range(1):
    model.eval()
    with torch.no_grad():
        # Evaluate on the test loader.
        total_correct = 0
        total_num = 0
        for item in data_loader2:
            x, labels = item['image'].to(device), item['labels']

            logits = model(x)
            # One-hot row marking the top-scoring class of each sample.
            pred = props_to_onehot(logits.cpu().numpy())

            # Iterate over x.size(0) rather than `bs`: the last batch may be
            # smaller, and indexing up to `bs` would go out of range.
            # Labels are converted per sample with int() exactly as before,
            # but stacked once per batch instead of one device tensor each.
            target = torch.tensor([list(map(int, labels[i])) for i in range(x.size(0))])
            pred_t = torch.as_tensor(pred, dtype=torch.long)
            # Row-wise dot product: contributes 1 when the predicted class is
            # among the sample's labels, i.e. the prediction counts as correct.
            total_correct += (pred_t * target).sum().item()
            total_num += x.size(0)

        # Guard against an empty test split instead of dividing by zero.
        acc = total_correct / total_num if total_num else float('nan')
        train_acc.append(acc)
        print('epoch: ', epoch, '\tacc: ', acc, '\n')

train batches:  70
test batches:  18


