diff --git a/.gitignore b/.gitignore index b6e47617..00b5348c 100644 --- a/.gitignore +++ b/.gitignore @@ -127,3 +127,8 @@ dmypy.json # Pyre type checker .pyre/ + +data/ + +*.pth +*.onnx \ No newline at end of file diff --git a/src/config.py b/src/config.py index 7d07fc0e..4ab4db54 100644 --- a/src/config.py +++ b/src/config.py @@ -17,8 +17,9 @@ def mkdir(path, remove=False): class config: - remv = True - data_path = os.path.join('..', 'data') + remv = False + task_name = 'seaplane' + data_path = os.path.join('..', 'data', task_name) labeled_data_path = os.path.join(data_path, 'labeled') mkdir(data_path) @@ -33,8 +34,9 @@ class config: label_file_path = os.path.join(labeled_data_path, 'label.txt') lr = 0.005 - epochs = 20 - batch_size = 1 + epochs = 100 + batch_size = 16 + log_interval = 100 img_transform = torchvision.transforms.Compose([ # torchvision.transforms.Grayscale(num_output_channels=1), diff --git a/src/main.py b/src/main.py index 94c23ceb..b6bf2cd4 100644 --- a/src/main.py +++ b/src/main.py @@ -42,7 +42,7 @@ def train(): loss = criterion(out, label) loss.backward() optimizer.step() - if (i + 1) % 10 == 0: + if (i + 1) % config.log_interval == 0: print(f'epoch: {epoch + 1}, iter: {i + 1}, loss: {loss.item():.4f}') total_loss += loss.item() total_acc += torch.sum(torch.argmax(out, dim=1) == label).item() @@ -51,6 +51,7 @@ def train(): ) torch.save(model.state_dict(), config.model_path) + model = model.cpu() model.eval() torch.onnx.export(model, torch.randn(1, 3, 64, 64),