In [1]:
import torch
from torch import nn, optim
from torchvision import transforms, utils
from torch.utils.data import DataLoader, random_split
from models import resnet18
from dataset import PosterDataset, Resize, ToTensor
from matplotlib import pyplot as plt
%matplotlib inline

In [2]:
transformed_dataset = PosterDataset(csv_file='./data.txt',
                                    root_dir='../posters/',
                                    transform=transforms.Compose([
                                        Resize(),
                                        ToTensor()
                                    ]))
train_size = int(0.8*len(transformed_dataset)+1)
test_size = int(0.2*len(transformed_dataset))
train_dataset, test_dataset = random_split(transformed_dataset, [train_size, test_size])
data_loader1 = DataLoader(train_dataset, batch_size=8,shuffle=True)
data_loader2 = DataLoader(test_dataset, batch_size=8,shuffle=True)
print(len(data_loader1))
print(len(data_loader2))

278
70


In [3]:
device = torch.device('cuda')
model = resnet18().to(device)
criteon = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
train_loss = []
train_acc = []

for epoch in range(1):
    model.train()
    for idx, item in enumerate(data_loader1):
#         print(idx)
#         print(x)
#         print(label)
#         print(title)
        x, label = item['image'].to(device), item['label'].to(device)
        logits = model(x)
        loss = criteon(logits, label)
        train_loss.append(loss.item())

        # backporp
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        print('epoch: ', epoch, ' [', idx, '/', len(data_loader1), '] ', 'loss: ', loss.item())

    print('epoch: ', epoch, 'loss: ', loss.item())

    model.eval()
    with torch.no_grad():
        # test
        total_correct = 0
        total_num = 0
        for idx, (x, label, title) in data_loader2:
            x, label = x.to(device), label.to(device)
            logits = model(x)
            pred = logits.argmax(dim=1)
            total_correct += torch.eq(pred, label).float().sum().item()
            total_num += x.size(0)

        acc = total_correct / total_num
        train_acc.append(acc)
        print('acc: ', acc, '\n')

    # 保存神经网络
    if epoch % 100 == 0:
        torch.save(model, 'net.pkl')                      # 保存整个神经网络的结构和模型参数
        print('saved in net.pkl')

plot_curve(train_loss, 'loss')
plot_curve(train_acc, 'acc')

# 保存神经网络
torch.save(model, 'net.pkl')                      # 保存整个神经网络的结构和模型参数
#     torch.save(net.state_dict(), 'net_params.pkl')  # 只保存神经网络的模型参数
print('saved in net.pkl')

epoch:  0  [ 0 / 278 ]  loss:  2.974290370941162
epoch:  0  [ 1 / 278 ]  loss:  6.352252006530762
epoch:  0  [ 2 / 278 ]  loss:  3.431450128555298
epoch:  0  [ 3 / 278 ]  loss:  10.02437973022461
epoch:  0  [ 4 / 278 ]  loss:  2.7245047092437744
epoch:  0  [ 5 / 278 ]  loss:  7.656571388244629
epoch:  0  [ 6 / 278 ]  loss:  8.653200149536133
epoch:  0  [ 7 / 278 ]  loss:  5.788845062255859
epoch:  0  [ 8 / 278 ]  loss:  6.65566873550415
epoch:  0  [ 9 / 278 ]  loss:  3.440462112426758
epoch:  0  [ 10 / 278 ]  loss:  2.7505438327789307
epoch:  0  [ 11 / 278 ]  loss:  3.4973886013031006
epoch:  0  [ 12 / 278 ]  loss:  3.2377116680145264
epoch:  0  [ 13 / 278 ]  loss:  5.8137969970703125
epoch:  0  [ 14 / 278 ]  loss:  3.0313119888305664
epoch:  0  [ 15 / 278 ]  loss:  4.761716365814209
epoch:  0  [ 16 / 278 ]  loss:  2.778014898300171
epoch:  0  [ 17 / 278 ]  loss:  2.913270950317383
epoch:  0  [ 18 / 278 ]  loss:  4.094699859619141
epoch:  0  [ 19 / 278 ]  loss:  3.3849384784698486
epoc

  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))
  " Skipping tag %s" % (size, len(data), tag))


epoch:  0  [ 73 / 278 ]  loss:  2.3338358402252197
epoch:  0  [ 74 / 278 ]  loss:  2.488616466522217
epoch:  0  [ 75 / 278 ]  loss:  2.4744362831115723
epoch:  0  [ 76 / 278 ]  loss:  2.435516357421875
epoch:  0  [ 77 / 278 ]  loss:  2.5986361503601074
epoch:  0  [ 78 / 278 ]  loss:  2.659709930419922
epoch:  0  [ 79 / 278 ]  loss:  2.415632724761963
epoch:  0  [ 80 / 278 ]  loss:  2.2891411781311035
epoch:  0  [ 81 / 278 ]  loss:  2.170086145401001
epoch:  0  [ 82 / 278 ]  loss:  2.598045587539673
epoch:  0  [ 83 / 278 ]  loss:  2.837629556655884
epoch:  0  [ 84 / 278 ]  loss:  2.3145039081573486
epoch:  0  [ 85 / 278 ]  loss:  2.4555115699768066
epoch:  0  [ 86 / 278 ]  loss:  2.2946152687072754
epoch:  0  [ 87 / 278 ]  loss:  2.0851337909698486
epoch:  0  [ 88 / 278 ]  loss:  2.4678030014038086
epoch:  0  [ 89 / 278 ]  loss:  2.6255576610565186
epoch:  0  [ 90 / 278 ]  loss:  2.5258450508117676
epoch:  0  [ 91 / 278 ]  loss:  2.6164238452911377
epoch:  0  [ 92 / 278 ]  loss:  2.3752

  " Skipping tag %s" % (size, len(data), tag))


epoch:  0  [ 148 / 278 ]  loss:  2.24196195602417
epoch:  0  [ 149 / 278 ]  loss:  2.283681869506836
epoch:  0  [ 150 / 278 ]  loss:  2.483147144317627
epoch:  0  [ 151 / 278 ]  loss:  2.5437941551208496
epoch:  0  [ 152 / 278 ]  loss:  2.599973440170288
epoch:  0  [ 153 / 278 ]  loss:  2.669032335281372
epoch:  0  [ 154 / 278 ]  loss:  2.9106225967407227
epoch:  0  [ 155 / 278 ]  loss:  3.263371229171753
epoch:  0  [ 156 / 278 ]  loss:  2.2998735904693604
epoch:  0  [ 157 / 278 ]  loss:  1.9890421628952026
epoch:  0  [ 158 / 278 ]  loss:  2.65230393409729
epoch:  0  [ 159 / 278 ]  loss:  2.758401870727539
epoch:  0  [ 160 / 278 ]  loss:  2.439993381500244
epoch:  0  [ 161 / 278 ]  loss:  2.1920032501220703
epoch:  0  [ 162 / 278 ]  loss:  2.34610915184021
epoch:  0  [ 163 / 278 ]  loss:  2.3474650382995605
epoch:  0  [ 164 / 278 ]  loss:  2.2968344688415527
epoch:  0  [ 165 / 278 ]  loss:  1.8239208459854126
epoch:  0  [ 166 / 278 ]  loss:  2.1174328327178955
epoch:  0  [ 167 / 278 ] 

KeyboardInterrupt: 