In [1]:
%load_ext autoreload
%autoreload 2

import torch
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from tqdm.notebook import tqdm
from utils.model import Resnet
from utils.dataset import *



In [2]:
full_data = Image_dataset(r'D:\Folder\Vscode\Git\Example\Data\Cassava_leaf_disease_classification',mode='train')
batch_size = 12
train_data,val_data,test_data = random_split(full_data,[0.8,0.15,0.05])
train_loader = DataLoader(train_data,batch_size=batch_size,shuffle=True)
val_loader = DataLoader(val_data,batch_size=batch_size,shuffle=False)
test_loader = DataLoader(test_data,batch_size=batch_size,shuffle=False)


device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Resnet()
model = model.to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=1e-4)

params_num = 0
for i in model.parameters():
    params_num += i.numel()
print('参数量:',params_num)

def val_acc():
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in val_loader:
            inputs,labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _,predicted = torch.max(outputs.data,1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct/total

参数量: 404480


In [3]:
epochs = 10
best_acc = 0
for epoch in range(epochs):
    accuracy = 0
    model.train()
    for step,(image,label) in enumerate(tqdm(train_loader)):
        image = image.to(device)
        label = label.to(device)
        output = model(image)
        loss = criterion(output,label)
        _,predicted = torch.max(output.data,1)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        accuracy = accuracy + (predicted == label).sum().item()
        print(f'Step: {step}, Loss: {loss.item()}, Accuracy: {accuracy/((step+1)*batch_size)}')
    val_accuracy = val_acc()
    print(f'Epoch: {epoch}, Val_Accuracy: {val_accuracy}')
    if val_accuracy > best_acc:
        best_acc = val_accuracy
        torch.save(model.state_dict(),f'model{val_accuracy:.2f}.pt')

  0%|          | 0/1427 [00:00<?, ?it/s]

Step: 0, Loss: 1.5580545663833618, Accuracy: 0.3333333333333333
Step: 1, Loss: 1.4920635223388672, Accuracy: 0.4166666666666667
Step: 2, Loss: 1.5511740446090698, Accuracy: 0.3888888888888889
Step: 3, Loss: 1.466092586517334, Accuracy: 0.375
Step: 4, Loss: 1.578878402709961, Accuracy: 0.35
Step: 5, Loss: 1.5488241910934448, Accuracy: 0.3611111111111111
Step: 6, Loss: 1.4900994300842285, Accuracy: 0.38095238095238093
Step: 7, Loss: 1.409048080444336, Accuracy: 0.40625
Step: 8, Loss: 1.562271237373352, Accuracy: 0.4074074074074074
Step: 9, Loss: 1.3928823471069336, Accuracy: 0.43333333333333335
Step: 10, Loss: 1.4524726867675781, Accuracy: 0.45454545454545453
Step: 11, Loss: 1.4625669717788696, Accuracy: 0.4583333333333333
Step: 12, Loss: 1.4560036659240723, Accuracy: 0.46153846153846156
Step: 13, Loss: 1.4997342824935913, Accuracy: 0.4642857142857143
Step: 14, Loss: 1.6117496490478516, Accuracy: 0.45
Step: 15, Loss: 1.459421157836914, Accuracy: 0.4583333333333333
Step: 16, Loss: 1.42627

  0%|          | 0/1427 [00:00<?, ?it/s]

Step: 0, Loss: 1.188383936882019, Accuracy: 0.75
Step: 1, Loss: 1.0209969282150269, Accuracy: 0.8333333333333334
Step: 2, Loss: 1.4193825721740723, Accuracy: 0.7222222222222222
Step: 3, Loss: 1.295035481452942, Accuracy: 0.6875
Step: 4, Loss: 1.1507552862167358, Accuracy: 0.7
Step: 5, Loss: 1.1542855501174927, Accuracy: 0.7083333333333334
Step: 6, Loss: 1.2335125207901, Accuracy: 0.7023809523809523
Step: 7, Loss: 1.1537567377090454, Accuracy: 0.7083333333333334
Step: 8, Loss: 1.2135430574417114, Accuracy: 0.7037037037037037
Step: 9, Loss: 1.3749476671218872, Accuracy: 0.6833333333333333
Step: 10, Loss: 1.3811321258544922, Accuracy: 0.6666666666666666
Step: 11, Loss: 1.359480381011963, Accuracy: 0.6527777777777778
Step: 12, Loss: 1.21206796169281, Accuracy: 0.6602564102564102
Step: 13, Loss: 1.4647070169448853, Accuracy: 0.6428571428571429
Step: 14, Loss: 1.3950835466384888, Accuracy: 0.6333333333333333
Step: 15, Loss: 1.1963226795196533, Accuracy: 0.6354166666666666
Step: 16, Loss: 1.0

  0%|          | 0/1427 [00:00<?, ?it/s]

Step: 0, Loss: 1.1828513145446777, Accuracy: 0.6666666666666666
Step: 1, Loss: 1.222298264503479, Accuracy: 0.6666666666666666
Step: 2, Loss: 1.3607488870620728, Accuracy: 0.6111111111111112
Step: 3, Loss: 1.0309314727783203, Accuracy: 0.6875
Step: 4, Loss: 1.2319201231002808, Accuracy: 0.6833333333333333
Step: 5, Loss: 1.2099920511245728, Accuracy: 0.6944444444444444
Step: 6, Loss: 1.2119311094284058, Accuracy: 0.6904761904761905
Step: 7, Loss: 1.6092419624328613, Accuracy: 0.6354166666666666
Step: 8, Loss: 1.3218647241592407, Accuracy: 0.6296296296296297
Step: 9, Loss: 1.5279854536056519, Accuracy: 0.6
Step: 10, Loss: 1.1661598682403564, Accuracy: 0.6136363636363636
Step: 11, Loss: 1.3133727312088013, Accuracy: 0.6111111111111112
Step: 12, Loss: 1.3659433126449585, Accuracy: 0.6089743589743589
Step: 13, Loss: 1.2697540521621704, Accuracy: 0.6130952380952381
Step: 14, Loss: 1.4977415800094604, Accuracy: 0.5944444444444444
Step: 15, Loss: 1.3720966577529907, Accuracy: 0.588541666666666

  0%|          | 0/1427 [00:00<?, ?it/s]

Step: 0, Loss: 1.4626669883728027, Accuracy: 0.4166666666666667
Step: 1, Loss: 1.1571117639541626, Accuracy: 0.5833333333333334
Step: 2, Loss: 1.3392633199691772, Accuracy: 0.5833333333333334
Step: 3, Loss: 1.2823463678359985, Accuracy: 0.5833333333333334
Step: 4, Loss: 1.2998541593551636, Accuracy: 0.5833333333333334
Step: 5, Loss: 1.0719369649887085, Accuracy: 0.625
Step: 6, Loss: 1.1058017015457153, Accuracy: 0.6547619047619048
Step: 7, Loss: 1.2479246854782104, Accuracy: 0.65625
Step: 8, Loss: 1.1681421995162964, Accuracy: 0.6574074074074074
Step: 9, Loss: 1.2543463706970215, Accuracy: 0.6583333333333333
Step: 10, Loss: 1.1273103952407837, Accuracy: 0.6666666666666666
Step: 11, Loss: 1.386893391609192, Accuracy: 0.6527777777777778
Step: 12, Loss: 1.130466103553772, Accuracy: 0.6602564102564102
Step: 13, Loss: 1.295440912246704, Accuracy: 0.6547619047619048
Step: 14, Loss: 1.171133041381836, Accuracy: 0.6611111111111111
Step: 15, Loss: 1.0983192920684814, Accuracy: 0.671875
Step: 16

  0%|          | 0/1427 [00:00<?, ?it/s]

Step: 0, Loss: 1.1177719831466675, Accuracy: 0.75
Step: 1, Loss: 1.1211992502212524, Accuracy: 0.75
Step: 2, Loss: 1.0784947872161865, Accuracy: 0.7777777777777778
Step: 3, Loss: 1.165144920349121, Accuracy: 0.7708333333333334
Step: 4, Loss: 1.0112468004226685, Accuracy: 0.8
Step: 5, Loss: 1.2200955152511597, Accuracy: 0.7777777777777778
Step: 6, Loss: 1.1527435779571533, Accuracy: 0.7738095238095238
Step: 7, Loss: 1.0668305158615112, Accuracy: 0.78125
Step: 8, Loss: 1.257228970527649, Accuracy: 0.7685185185185185
Step: 9, Loss: 1.220579981803894, Accuracy: 0.7583333333333333
Step: 10, Loss: 0.9840849041938782, Accuracy: 0.7727272727272727
Step: 11, Loss: 1.1655620336532593, Accuracy: 0.7708333333333334
Step: 12, Loss: 1.2156217098236084, Accuracy: 0.7628205128205128
Step: 13, Loss: 1.2668970823287964, Accuracy: 0.75
Step: 14, Loss: 1.2621673345565796, Accuracy: 0.7388888888888889
Step: 15, Loss: 1.2497659921646118, Accuracy: 0.734375
Step: 16, Loss: 1.2897368669509888, Accuracy: 0.725

  0%|          | 0/1427 [00:00<?, ?it/s]

Step: 0, Loss: 1.0416934490203857, Accuracy: 0.8333333333333334
Step: 1, Loss: 1.542259693145752, Accuracy: 0.5833333333333334
Step: 2, Loss: 1.0213700532913208, Accuracy: 0.6944444444444444
Step: 3, Loss: 1.1803377866744995, Accuracy: 0.7083333333333334
Step: 4, Loss: 1.367872714996338, Accuracy: 0.6666666666666666
Step: 5, Loss: 1.2079051733016968, Accuracy: 0.6805555555555556
Step: 6, Loss: 1.1822892427444458, Accuracy: 0.6904761904761905
Step: 7, Loss: 1.1402537822723389, Accuracy: 0.6979166666666666
Step: 8, Loss: 1.0685665607452393, Accuracy: 0.7129629629629629
Step: 9, Loss: 1.3326153755187988, Accuracy: 0.7
Step: 10, Loss: 1.2680188417434692, Accuracy: 0.696969696969697
Step: 11, Loss: 1.3089183568954468, Accuracy: 0.6875
Step: 12, Loss: 1.0740231275558472, Accuracy: 0.6987179487179487
Step: 13, Loss: 1.413521409034729, Accuracy: 0.6845238095238095
Step: 14, Loss: 1.2514692544937134, Accuracy: 0.6833333333333333
Step: 15, Loss: 1.0955621004104614, Accuracy: 0.6927083333333334
S

  0%|          | 0/1427 [00:00<?, ?it/s]

Step: 0, Loss: 1.211787462234497, Accuracy: 0.6666666666666666
Step: 1, Loss: 1.2415269613265991, Accuracy: 0.6666666666666666
Step: 2, Loss: 1.118996500968933, Accuracy: 0.7222222222222222
Step: 3, Loss: 1.1647453308105469, Accuracy: 0.7291666666666666
Step: 4, Loss: 1.2293450832366943, Accuracy: 0.7166666666666667
Step: 5, Loss: 1.337175726890564, Accuracy: 0.6944444444444444
Step: 6, Loss: 1.1734076738357544, Accuracy: 0.6904761904761905
Step: 7, Loss: 1.230141520500183, Accuracy: 0.6875
Step: 8, Loss: 1.2219648361206055, Accuracy: 0.6851851851851852
Step: 9, Loss: 1.008162260055542, Accuracy: 0.7083333333333334
Step: 10, Loss: 1.2413091659545898, Accuracy: 0.7045454545454546
Step: 11, Loss: 1.071208119392395, Accuracy: 0.7152777777777778
Step: 12, Loss: 1.3923486471176147, Accuracy: 0.6987179487179487
Step: 13, Loss: 1.0842052698135376, Accuracy: 0.7083333333333334
Step: 14, Loss: 1.2075461149215698, Accuracy: 0.7111111111111111
Step: 15, Loss: 1.4806767702102661, Accuracy: 0.69270

  0%|          | 0/1427 [00:00<?, ?it/s]

Step: 0, Loss: 1.2283217906951904, Accuracy: 0.6666666666666666
Step: 1, Loss: 1.2300094366073608, Accuracy: 0.6666666666666666
Step: 2, Loss: 1.124159336090088, Accuracy: 0.7222222222222222
Step: 3, Loss: 1.401976466178894, Accuracy: 0.6666666666666666
Step: 4, Loss: 1.2035797834396362, Accuracy: 0.6833333333333333
Step: 5, Loss: 1.1496084928512573, Accuracy: 0.6944444444444444
Step: 6, Loss: 1.3124873638153076, Accuracy: 0.6785714285714286
Step: 7, Loss: 1.2981981039047241, Accuracy: 0.6666666666666666
Step: 8, Loss: 1.2547653913497925, Accuracy: 0.6666666666666666
Step: 9, Loss: 1.0964750051498413, Accuracy: 0.6833333333333333
Step: 10, Loss: 1.231866478919983, Accuracy: 0.6818181818181818
Step: 11, Loss: 1.2489761114120483, Accuracy: 0.6736111111111112
Step: 12, Loss: 1.0767155885696411, Accuracy: 0.6858974358974359
Step: 13, Loss: 1.3202800750732422, Accuracy: 0.6785714285714286
Step: 14, Loss: 1.153103232383728, Accuracy: 0.6833333333333333
Step: 15, Loss: 1.2873693704605103, Acc

  0%|          | 0/1427 [00:00<?, ?it/s]

Step: 0, Loss: 1.2590357065200806, Accuracy: 0.6666666666666666
Step: 1, Loss: 1.2605406045913696, Accuracy: 0.6666666666666666
Step: 2, Loss: 1.242194652557373, Accuracy: 0.6666666666666666
Step: 3, Loss: 1.2192707061767578, Accuracy: 0.6666666666666666
Step: 4, Loss: 1.2694450616836548, Accuracy: 0.6666666666666666
Step: 5, Loss: 1.197769045829773, Accuracy: 0.6805555555555556
Step: 6, Loss: 0.972682535648346, Accuracy: 0.7261904761904762
Step: 7, Loss: 1.2203296422958374, Accuracy: 0.71875
Step: 8, Loss: 1.1715563535690308, Accuracy: 0.7129629629629629
Step: 9, Loss: 1.1510672569274902, Accuracy: 0.7166666666666667
Step: 10, Loss: 0.969387948513031, Accuracy: 0.7424242424242424
Step: 11, Loss: 1.2534576654434204, Accuracy: 0.7361111111111112
Step: 12, Loss: 1.1673613786697388, Accuracy: 0.7371794871794872
Step: 13, Loss: 1.2661255598068237, Accuracy: 0.7321428571428571
Step: 14, Loss: 1.0375102758407593, Accuracy: 0.7388888888888889
Step: 15, Loss: 1.2148884534835815, Accuracy: 0.73

  0%|          | 0/1427 [00:00<?, ?it/s]

Step: 0, Loss: 1.3812702894210815, Accuracy: 0.5
Step: 1, Loss: 1.3099004030227661, Accuracy: 0.5416666666666666
Step: 2, Loss: 1.026016116142273, Accuracy: 0.6666666666666666
Step: 3, Loss: 1.0909024477005005, Accuracy: 0.7083333333333334
Step: 4, Loss: 1.0846024751663208, Accuracy: 0.7333333333333333
Step: 5, Loss: 1.1418368816375732, Accuracy: 0.7361111111111112
Step: 6, Loss: 1.1463398933410645, Accuracy: 0.7380952380952381
Step: 7, Loss: 1.3770416975021362, Accuracy: 0.7083333333333334
Step: 8, Loss: 1.0808333158493042, Accuracy: 0.7222222222222222
Step: 9, Loss: 1.1912815570831299, Accuracy: 0.725
Step: 10, Loss: 1.2467681169509888, Accuracy: 0.7196969696969697
Step: 11, Loss: 1.2459087371826172, Accuracy: 0.7152777777777778
Step: 12, Loss: 1.1899181604385376, Accuracy: 0.7115384615384616
Step: 13, Loss: 1.0105644464492798, Accuracy: 0.7261904761904762
Step: 14, Loss: 1.0872701406478882, Accuracy: 0.7277777777777777
Step: 15, Loss: 1.4670292139053345, Accuracy: 0.7083333333333334

In [3]:
image = train_loader.dataset[0][0].cuda()
output = model(image.unsqueeze(0))