In [1]:
import  torch
from    torch import optim, nn
import  visdom
import  torchvision
from    torch.utils.data import DataLoader

from    resnet18 import ResNet18
from    densenet import DenseNet

In [2]:
from    stone import Stone
from    stonetest import StoneTest

In [3]:
batchsz = 7
lr = 1e-3
epochs = 20

device = torch.device('cuda')
torch.manual_seed(1234)


train_db = Stone('stone', 224, mode='train')
val_db = Stone('stone', 224, mode='val')
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True)
val_loader = DataLoader(val_db, batch_size=batchsz)





In [4]:
test_db = StoneTest('stonetest', 224, mode='test')
test_loader = DataLoader(test_db, batch_size=batchsz,shuffle=False)



viz = visdom.Visdom()

Setting up a new session...


In [5]:
def evalute(model, loader):
    model.eval() #实际估计时使用
    
    correct = 0
    cc=0
    total = len(loader.dataset)

    for x,y in loader:
        x,y = x.to(device), y.to(device)
        with torch.no_grad():
            logits = model(x)
            pred = logits.argmax(dim=1)
        correct += torch.eq(pred, y).sum().float().item()

    return correct / total


In [6]:
def main():

    #model = ResNet18(7).to(device)
    model = DenseNet(growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64,
                 bn_size=4, compression_rate=0.5, drop_rate=0, num_classes=7).to(device)
    
    #model.load_state_dict(torch.load('best.mdl'))
    
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()


    best_acc, best_epoch = 0, 0
    global_step = 0
    #viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    #viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))
    viz.line([0], [0], win='损失loss', opts=dict(title='损失loss'))
    viz.line([0], [0], win='交叉验证集测试结果', opts=dict(title='交叉验证集测试结果'))
    viz.line([0], [0], win='交叉验证集的测试结果', opts=dict(title='交叉验证集的测试结果'))
    for epoch in range(epochs):

        for step, (x,y) in enumerate(train_loader):

            # x: [b, 3, 224, 224], y: [b]
            x, y = x.to(device), y.to(device)
            
            model.train()
            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            #viz.line([loss.item()], [global_step], win='loss', update='append')
            viz.line([loss.item()], [global_step], win='损失loss', update='append')
            global_step += 1
            #print('global_step:', global_step,'loss.item:',loss.item())
        print('global_step:', global_step,'loss.item:',loss.item())

        if epoch % 1 == 0:

            val_acc = evalute(model, val_loader)
            
            print('val_acc:', val_acc, 'epoch:', epoch)


            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc

                torch.save(model.state_dict(), 'best.mdl')
                viz.line([val_acc], [epoch+1], win='交叉验证集测试结果', update='append')
                
            #viz.line([val_acc], [global_step], win='val_acc', update='append')
            viz.line([val_acc], [epoch+1], win='交叉验证集的测试结果', update='append')


                
                
    print('best acc:', best_acc, 'best epoch:', best_epoch)

    model.load_state_dict(torch.load('best.mdl'))
    print('loaded from ckpt!')

    test_acc = evalute(model, test_loader)
    print('test acc:', test_acc)

In [7]:
if __name__ == '__main__':
    main()

global_step: 306 loss.item: 1.693791389465332
val_acc: 0.4369747899159664 epoch: 0
global_step: 612 loss.item: 1.459078073501587
val_acc: 0.44677871148459386 epoch: 1
global_step: 918 loss.item: 1.386906385421753
val_acc: 0.6176470588235294 epoch: 2
global_step: 1224 loss.item: 0.783399224281311
val_acc: 0.5630252100840336 epoch: 3
global_step: 1530 loss.item: 0.6796776056289673
val_acc: 0.5672268907563025 epoch: 4
global_step: 1836 loss.item: 1.0090464353561401
val_acc: 0.5770308123249299 epoch: 5
global_step: 2142 loss.item: 1.7514547109603882
val_acc: 0.484593837535014 epoch: 6
global_step: 2448 loss.item: 1.2047598361968994
val_acc: 0.6372549019607843 epoch: 7
global_step: 2754 loss.item: 1.0750868320465088
val_acc: 0.19047619047619047 epoch: 8
global_step: 3060 loss.item: 0.9981099963188171
val_acc: 0.6078431372549019 epoch: 9
global_step: 3366 loss.item: 0.824989914894104
val_acc: 0.6218487394957983 epoch: 10
global_step: 3672 loss.item: 0.41418957710266113
val_acc: 0.42016806722

In [8]:
model = ResNet18(7).to(device)
#model.load_state_dict(torch.load('best.mdl'))
#model = DenseNet(growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64,
#             bn_size=4, compression_rate=0.5, drop_rate=0, num_classes=7).to(device)
print('loaded from ckpt!')

loaded from ckpt!


In [9]:
test_acc = evalute(model, test_loader)
print('test acc:', test_acc)

test acc: 0.6428571428571429
