/
main.py
94 lines (76 loc) · 2.53 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import cv2
from PIL import Image
from config import config
from nn.resnet_mini import ResNetMini
def _train_one_epoch(model, data_loader, optimizer, criterion, device, epoch):
    """Run one optimization pass over ``data_loader``.

    Args:
        model: the network being trained (already in train mode, on ``device``).
        data_loader: yields ``(img, label)`` batches.
        optimizer: optimizer stepping ``model``'s parameters.
        criterion: loss function applied to ``(model(img), label)``.
        device: device the batches are moved to.
        epoch: zero-based epoch index, used only for logging.

    Returns:
        ``(loss_sum, correct)``: ``loss_sum`` is the sum of per-batch losses
        weighted by batch size, so dividing it by the dataset size gives the
        per-sample average loss; ``correct`` is the number of samples whose
        arg-max prediction matched the label.
    """
    loss_sum = 0.0
    correct = 0
    for i, (img, label) in enumerate(data_loader):
        img = img.to(device)
        label = label.to(device)
        optimizer.zero_grad()
        out = model(img)
        loss = criterion(out, label)
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        if (i + 1) % config.log_interval == 0:
            print(f'epoch: {epoch + 1}, iter: {i + 1}, loss: {batch_loss:.4f}')
        # CrossEntropyLoss averages over the batch by default (reduction='mean'),
        # so weight by the batch size to recover a per-sample sum; this stays
        # exact even when the last batch is shorter than batch_size.
        loss_sum += batch_loss * label.size(0)
        correct += (out.argmax(dim=1) == label).sum().item()
    return loss_sum, correct


def train():
    """Train a ``ResNetMini(3, 2)`` classifier and export it.

    Hyper-parameters and paths are read from ``config``: ``lr``,
    ``batch_size``, ``epochs``, ``log_interval``, ``img_transform``,
    ``labeled_data_path``, ``model_path`` and ``model_onnx_path``. Images are
    loaded with ``torchvision.datasets.ImageFolder`` and the model is optimized
    with Adam on cross-entropy loss.

    After training, the ``state_dict`` is saved to ``config.model_path``. The
    model is then moved to CPU, put in eval mode and exported to ONNX at
    ``config.model_onnx_path``, traced with a random ``1x3x64x64`` input.
    """
    # Prefer CUDA, but fall back to CPU so the script still runs without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ResNetMini(3, 2)
    model.train()
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
    criterion = nn.CrossEntropyLoss()
    print('model:', model)
    data = torchvision.datasets.ImageFolder(config.labeled_data_path,
                                            transform=config.img_transform)
    data_loader = torch.utils.data.DataLoader(data, batch_size=config.batch_size, shuffle=True)
    print(f'{len(data)} images')
    # Standard cross-entropy training (no focal loss is used here).
    for epoch in range(config.epochs):
        loss_sum, correct = _train_one_epoch(model, data_loader, optimizer,
                                             criterion, device, epoch)
        print(
            f'epoch: {epoch + 1}, avg loss: {loss_sum / len(data):.4f}, avg acc: {correct / len(data):.4f}'
        )
    torch.save(model.state_dict(), config.model_path)
    # Export on CPU in eval mode; the dummy input fixes the traced input shape.
    model = model.cpu()
    model.eval()
    torch.onnx.export(model,
                      torch.randn(1, 3, 64, 64),
                      config.model_onnx_path,
                      verbose=True,
                      export_params=True)
def test_single(model, img):
    """Classify a single image and return the predicted class index.

    ``img`` is preprocessed with ``config.img_transform`` and batched to size 1.
    NOTE(review): presumably ``img`` is a PIL image, as used by torchvision
    transforms — confirm against callers. The batch is moved to whatever
    device holds ``model``'s parameters, falling back to CPU for a
    parameter-less model.

    The caller is responsible for putting ``model`` in eval mode; this
    function does not change the model's mode.

    Returns:
        0 if the arg-max class is 0, otherwise 1.
    """
    # Follow the model's device rather than hard-coding CUDA, so CPU models work too.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    batch = config.img_transform(img).unsqueeze(0).to(device)
    # Pure inference: don't build an autograd graph.
    with torch.no_grad():
        out = model(batch)
    pred = torch.argmax(out, dim=1)
    return 0 if pred.item() == 0 else 1
def test():
    """Placeholder for a model-evaluation routine.

    NOTE(review): unfinished stub. It only constructs an untrained
    ``ResNetMini(3, 2)`` and discards it. No weights are loaded, no data is
    read and nothing is returned (implicitly ``None``).
    """
    net = ResNetMini(3, 2)  # noqa: F841 - unused until evaluation is implemented
def transfer_model():
    """Export the checkpoint stored at ``config.model_path`` to ONNX.

    Builds a fresh ``ResNetMini(3, 2)``, loads the saved ``state_dict``, puts
    the model in eval mode and writes the ONNX graph to
    ``config.model_onnx_path``. The graph is traced with a random
    ``1x3x64x64`` dummy input.
    """
    model = ResNetMini(3, 2)
    # A checkpoint saved from a CUDA model stores CUDA tensors. Map them onto
    # the CPU so loading (and the CPU-side export) also works without a GPU.
    state_dict = torch.load(config.model_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()
    torch.onnx.export(model,
                      torch.randn(1, 3, 64, 64),
                      config.model_onnx_path,
                      verbose=False,
                      export_params=True)
if __name__ == '__main__':
    # Run the pipeline in order: train() trains, saves the checkpoint and
    # exports ONNX itself. transfer_model() then re-exports ONNX from the
    # checkpoint on disk to the same path.
    for stage in (train, transfer_model):
        stage()