Skip to content

Commit

Permalink
build: train seaplane
Browse files Browse the repository at this point in the history
  • Loading branch information
beiyuouo committed May 1, 2022
1 parent f5c1cec commit 449442a
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 5 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Expand Up @@ -127,3 +127,8 @@ dmypy.json

# Pyre type checker
.pyre/

data/

*.pth
*.onnx
10 changes: 6 additions & 4 deletions src/config.py
Expand Up @@ -17,8 +17,9 @@ def mkdir(path, remove=False):


class config:
remv = True
data_path = os.path.join('..', 'data')
remv = False
task_name = 'seaplane'
data_path = os.path.join('..', 'data', task_name)
labeled_data_path = os.path.join(data_path, 'labeled')

mkdir(data_path)
Expand All @@ -33,8 +34,9 @@ class config:
label_file_path = os.path.join(labeled_data_path, 'label.txt')

lr = 0.005
epochs = 20
batch_size = 1
epochs = 100
batch_size = 16
log_interval = 100

img_transform = torchvision.transforms.Compose([
# torchvision.transforms.Grayscale(num_output_channels=1),
Expand Down
3 changes: 2 additions & 1 deletion src/main.py
Expand Up @@ -42,7 +42,7 @@ def train():
loss = criterion(out, label)
loss.backward()
optimizer.step()
if (i + 1) % 10 == 0:
if (i + 1) % config.log_interval == 0:
print(f'epoch: {epoch + 1}, iter: {i + 1}, loss: {loss.item():.4f}')
total_loss += loss.item()
total_acc += torch.sum(torch.argmax(out, dim=1) == label).item()
Expand All @@ -51,6 +51,7 @@ def train():
)

torch.save(model.state_dict(), config.model_path)
model = model.cpu()
model.eval()
torch.onnx.export(model,
torch.randn(1, 3, 64, 64),
Expand Down

0 comments on commit 449442a

Please sign in to comment.