In [6]:
import  torch
from    torch import optim, nn
import  visdom
import  torchvision
from    torch.utils.data import DataLoader

from    resnet18 import ResNet18
#from    densenet import DenseNet

In [7]:
from    stone import Stone
from    stone1 import Stone1

In [8]:
# ---- Hyper-parameters ----
batchsz = 7
lr = 1e-3
epochs = 30

# Prefer the GPU, but fall back to CPU so the notebook still runs on
# machines without CUDA (torch.device('cuda') alone fails there on first use).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Seed torch's RNG (weight init, DataLoader shuffling) for repeatable runs.
torch.manual_seed(1234)

# Train / validation splits of the 'stone' dataset. The 224 argument is
# presumably the image size (main() notes x: [b, 3, 224, 224]) — TODO confirm.
train_db = Stone('stone', 224, mode='train')
val_db = Stone('stone', 224, mode='val')
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True)
val_loader = DataLoader(val_db, batch_size=batchsz)





In [9]:
# Held-out test split; iterated in a fixed order (no shuffling) so the
# reported test accuracy is reproducible.
test_db = Stone1('stone', 224, mode='test')
test_loader = DataLoader(
    test_db,
    batch_size=batchsz,
    shuffle=False,
)

# Visdom client that main() uses to stream loss / accuracy curves.
viz = visdom.Visdom()

Setting up a new session...


In [10]:
def evalute(model, loader, target_device=None):
    """Return the top-1 classification accuracy of ``model`` over ``loader``.

    The model is put in eval mode (inference behaviour for layers such as
    dropout / batch-norm) and is left in that mode; callers that continue
    training must switch back with ``model.train()``.

    Args:
        model: module mapping a batch ``x`` to per-class logits; the
            predicted class is ``logits.argmax(dim=1)``.
        loader: iterable of ``(x, y)`` batches that exposes ``.dataset``
            (e.g. a ``DataLoader``); accuracy is normalised by
            ``len(loader.dataset)``.
        target_device: device each batch is moved to. Defaults to the
            notebook-level ``device``.

    Returns:
        Accuracy as a float in [0, 1]; 0.0 when the dataset is empty.
    """
    # Only consult the notebook-global `device` when no explicit device is given.
    dev = device if target_device is None else target_device
    model.eval()  # used for actual evaluation

    total = len(loader.dataset)
    if total == 0:
        # Avoid ZeroDivisionError on an empty split.
        return 0.0

    correct = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(dev), y.to(dev)
            pred = model(x).argmax(dim=1)
            correct += torch.eq(pred, y).sum().item()

    return correct / total


In [11]:
def main():
    """Train ResNet18 on the Stone splits, keep the best checkpoint, report test accuracy.

    Each epoch trains on ``train_loader`` and then measures accuracy on
    ``val_loader``. The weights with the highest validation accuracy seen in
    this run are written to ``best.mdl``. After training they are reloaded and
    scored on ``test_loader``. Loss and accuracy curves are streamed to the
    notebook-level visdom client ``viz``.
    """
    ckpt_path = 'best.mdl'

    # 7 is presumably the number of classes (the DenseNet alternative below
    # uses num_classes=7) — TODO confirm against ResNet18's signature.
    model = ResNet18(7).to(device)
    #model = DenseNet(growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64,
    #             bn_size=4, compression_rate=0.5, drop_rate=0, num_classes=7).to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    best_acc, best_epoch = 0, 0
    # Guarantees best.mdl is written by *this* run even if val_acc never
    # exceeds 0; otherwise torch.load below would crash or load a stale file.
    saved_ckpt = False
    global_step = 0

    # Visdom windows (titles are Chinese):
    #   '损失loss'              - training loss per step
    #   '交叉验证集测试结果'     - validation accuracy, appended only when it improves
    #   '交叉验证集的测试结果'   - validation accuracy, appended every epoch
    viz.line([0], [0], win='损失loss', opts=dict(title='损失loss'))
    viz.line([0], [0], win='交叉验证集测试结果', opts=dict(title='交叉验证集测试结果'))
    viz.line([0], [0], win='交叉验证集的测试结果', opts=dict(title='交叉验证集的测试结果'))

    for epoch in range(epochs):
        # evalute() leaves the model in eval mode, so re-enable training
        # mode once at the start of every epoch.
        model.train()
        for x, y in train_loader:
            # x: [b, 3, 224, 224], y: [b]
            x, y = x.to(device), y.to(device)

            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_val = loss.item()  # read the scalar once per step
            viz.line([loss_val], [global_step], win='损失loss', update='append')
            global_step += 1
            print('global_step:', global_step, 'loss.item:', loss_val)

        # Validate after every epoch.
        val_acc = evalute(model, val_loader)
        print('val_acc:', val_acc, 'epoch:', epoch)

        if not saved_ckpt or val_acc > best_acc:
            best_epoch = epoch
            best_acc = val_acc
            torch.save(model.state_dict(), ckpt_path)
            saved_ckpt = True
            viz.line([val_acc], [epoch + 1], win='交叉验证集测试结果', update='append')

        viz.line([val_acc], [epoch + 1], win='交叉验证集的测试结果', update='append')

    print('best acc:', best_acc, 'best epoch:', best_epoch)

    if not saved_ckpt:
        # Only possible when epochs == 0: nothing was trained or saved.
        print('no checkpoint saved in this run; skipping test evaluation')
        return

    # map_location keeps tensors on the device the model lives on.
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    print('loaded from ckpt!')

    test_acc = evalute(model, test_loader)
    print('test acc:', test_acc)

In [None]:
if __name__ == '__main__':
    # Inside a notebook kernel __name__ is '__main__', so executing this cell
    # immediately runs the whole train -> validate -> test loop in main().
    main()

global_step: 1 loss.item: 2.3074634075164795
global_step: 2 loss.item: 17.286062240600586
global_step: 3 loss.item: 2.745180368423462
global_step: 4 loss.item: 1.0308853387832642
global_step: 5 loss.item: 3.2761521339416504
global_step: 6 loss.item: 9.268641471862793
global_step: 7 loss.item: 7.042732238769531
global_step: 8 loss.item: 3.8830296993255615
val_acc: 0.14285714285714285 epoch: 0
global_step: 9 loss.item: 14.473143577575684
global_step: 10 loss.item: 16.312253952026367
global_step: 11 loss.item: 10.252202987670898
global_step: 12 loss.item: 8.386587142944336
global_step: 13 loss.item: 2.464073657989502
global_step: 14 loss.item: 15.272127151489258
global_step: 15 loss.item: 8.305219650268555
global_step: 16 loss.item: 1.8907034397125244
val_acc: 0.2857142857142857 epoch: 1
global_step: 17 loss.item: 2.42449688911438
global_step: 18 loss.item: 1.7587850093841553
global_step: 19 loss.item: 1.4625808000564575
global_step: 20 loss.item: 24.406875610351562
global_step: 21 loss.i

global_step: 163 loss.item: 0.06507108360528946
global_step: 164 loss.item: 0.3193051517009735
global_step: 165 loss.item: 0.36764636635780334
global_step: 166 loss.item: 1.9448939561843872
global_step: 167 loss.item: 0.41164737939834595
global_step: 168 loss.item: 0.23092077672481537
val_acc: 0.2857142857142857 epoch: 20
global_step: 169 loss.item: 0.16702429950237274
global_step: 170 loss.item: 0.6847673654556274
global_step: 171 loss.item: 1.9854145050048828
global_step: 172 loss.item: 0.12144631892442703
global_step: 173 loss.item: 0.10228509455919266
global_step: 174 loss.item: 0.06797502934932709
global_step: 175 loss.item: 0.47763094305992126
global_step: 176 loss.item: 0.6124639511108398
val_acc: 0.14285714285714285 epoch: 21
global_step: 177 loss.item: 1.981806993484497
global_step: 178 loss.item: 0.8466750979423523
global_step: 179 loss.item: 0.2549563944339752
global_step: 180 loss.item: 0.14082762598991394
global_step: 181 loss.item: 5.9217329025268555
global_step: 182 loss

In [None]:
# Rebuild the network and reload the best weights that main() saved to best.mdl.
model = ResNet18(7).to(device)
# map_location makes the load work even if the checkpoint was written on a
# different device than the current one (e.g. GPU checkpoint, CPU runtime).
model.load_state_dict(torch.load('best.mdl', map_location=device))
print('loaded from ckpt!')

In [None]:
# Score the reloaded model on the held-out test split.
print('test acc:', evalute(model, test_loader))