Skip to content

Commit

Permalink
Worked on issue #872: Added script for running python files in exampl…
Browse files Browse the repository at this point in the history
…es/classification_3d and examples/classification_3d_ignite folders
  • Loading branch information
arp95 committed Aug 12, 2020
1 parent 745c940 commit 3103deb
Show file tree
Hide file tree
Showing 9 changed files with 172 additions and 136 deletions.
24 changes: 12 additions & 12 deletions examples/classification_3d/densenet_evaluation_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,16 @@ def main():

# IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
images = [
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI607-Guys-1097-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI175-HH-1570-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI385-HH-2078-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI344-Guys-0905-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI409-Guys-0960-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI584-Guys-1129-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI253-HH-1694-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI092-HH-1436-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI574-IOP-1156-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI585-Guys-1130-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI607-Guys-1097-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI175-HH-1570-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI385-HH-2078-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI344-Guys-0905-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI409-Guys-0960-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI584-Guys-1129-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI253-HH-1694-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI092-HH-1436-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI574-IOP-1156-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI585-Guys-1130-T1.nii.gz"]),
]

# 2 binary labels for gender classification: man and woman
Expand All @@ -52,10 +52,10 @@ def main():
val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())

# Create DenseNet121
device = torch.device("cuda:0")
device = torch.device("cpu")
model = monai.networks.nets.densenet.densenet121(spatial_dims=3, in_channels=1, out_channels=2).to(device)

model.load_state_dict(torch.load("best_metric_model.pth"))
model.load_state_dict(torch.load("best_metric_model_1.pth"))
model.eval()
with torch.no_grad():
num_correct = 0.0
Expand Down
24 changes: 12 additions & 12 deletions examples/classification_3d/densenet_evaluation_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,16 @@ def main():

# IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
images = [
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI607-Guys-1097-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI175-HH-1570-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI385-HH-2078-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI344-Guys-0905-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI409-Guys-0960-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI584-Guys-1129-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI253-HH-1694-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI092-HH-1436-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI574-IOP-1156-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI585-Guys-1130-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI607-Guys-1097-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI175-HH-1570-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI385-HH-2078-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI344-Guys-0905-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI409-Guys-0960-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI584-Guys-1129-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI253-HH-1694-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI092-HH-1436-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI574-IOP-1156-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI585-Guys-1130-T1.nii.gz"]),
]

# 2 binary labels for gender classification: man and woman
Expand All @@ -60,10 +60,10 @@ def main():
val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())

# Create DenseNet121
device = torch.device("cuda:0")
device = torch.device("cpu")
model = monai.networks.nets.densenet.densenet121(spatial_dims=3, in_channels=1, out_channels=2).to(device)

model.load_state_dict(torch.load("best_metric_model.pth"))
model.load_state_dict(torch.load("best_metric_model_2.pth"))
model.eval()
with torch.no_grad():
num_correct = 0.0
Expand Down
44 changes: 22 additions & 22 deletions examples/classification_3d/densenet_training_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,26 +29,26 @@ def main():

# IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
images = [
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI314-IOP-0889-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI249-Guys-1072-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI609-HH-2600-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI173-HH-1590-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI020-Guys-0700-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI342-Guys-0909-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI134-Guys-0780-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI577-HH-2661-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI066-Guys-0731-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI130-HH-1528-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI607-Guys-1097-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI175-HH-1570-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI385-HH-2078-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI344-Guys-0905-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI409-Guys-0960-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI584-Guys-1129-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI253-HH-1694-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI092-HH-1436-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI574-IOP-1156-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI585-Guys-1130-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI314-IOP-0889-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI249-Guys-1072-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI609-HH-2600-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI173-HH-1590-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI020-Guys-0700-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI342-Guys-0909-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI134-Guys-0780-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI577-HH-2661-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI066-Guys-0731-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI130-HH-1528-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI607-Guys-1097-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI175-HH-1570-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI385-HH-2078-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI344-Guys-0905-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI409-Guys-0960-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI584-Guys-1129-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI253-HH-1694-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI092-HH-1436-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI574-IOP-1156-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI585-Guys-1130-T1.nii.gz"]),
]

# 2 binary labels for gender classification: man and woman
Expand All @@ -73,7 +73,7 @@ def main():
val_loader = DataLoader(val_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available())

# Create DenseNet121, CrossEntropyLoss and Adam optimizer
device = torch.device("cuda:0")
device = torch.device("cpu")
model = monai.networks.nets.densenet.densenet121(spatial_dims=3, in_channels=1, out_channels=2).to(device)
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), 1e-5)
Expand Down Expand Up @@ -123,7 +123,7 @@ def main():
if metric > best_metric:
best_metric = metric
best_metric_epoch = epoch + 1
torch.save(model.state_dict(), "best_metric_model.pth")
torch.save(model.state_dict(), "best_metric_model_1.pth")
print("saved new best metric model")
print(
"current epoch: {} current accuracy: {:.4f} best accuracy: {:.4f} at epoch {}".format(
Expand Down
44 changes: 22 additions & 22 deletions examples/classification_3d/densenet_training_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,26 +29,26 @@ def main():

# IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
images = [
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI314-IOP-0889-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI249-Guys-1072-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI609-HH-2600-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI173-HH-1590-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI020-Guys-0700-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI342-Guys-0909-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI134-Guys-0780-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI577-HH-2661-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI066-Guys-0731-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI130-HH-1528-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI607-Guys-1097-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI175-HH-1570-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI385-HH-2078-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI344-Guys-0905-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI409-Guys-0960-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI584-Guys-1129-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI253-HH-1694-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI092-HH-1436-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI574-IOP-1156-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI585-Guys-1130-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI314-IOP-0889-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI249-Guys-1072-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI609-HH-2600-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI173-HH-1590-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI020-Guys-0700-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI342-Guys-0909-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI134-Guys-0780-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI577-HH-2661-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI066-Guys-0731-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI130-HH-1528-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI607-Guys-1097-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI175-HH-1570-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI385-HH-2078-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI344-Guys-0905-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI409-Guys-0960-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI584-Guys-1129-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI253-HH-1694-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI092-HH-1436-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI574-IOP-1156-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI585-Guys-1130-T1.nii.gz"]),
]

# 2 binary labels for gender classification: man and woman
Expand Down Expand Up @@ -92,7 +92,7 @@ def main():
val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())

# Create DenseNet121, CrossEntropyLoss and Adam optimizer
device = torch.device("cuda:0")
device = torch.device("cpu")
model = monai.networks.nets.densenet.densenet121(spatial_dims=3, in_channels=1, out_channels=2).to(device)
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), 1e-5)
Expand Down Expand Up @@ -139,7 +139,7 @@ def main():
if acc_metric > best_metric:
best_metric = acc_metric
best_metric_epoch = epoch + 1
torch.save(model.state_dict(), "best_metric_model.pth")
torch.save(model.state_dict(), "best_metric_model_2.pth")
print("saved new best metric model")
print(
"current epoch: {} current accuracy: {:.4f} current AUC: {:.4f} best accuracy: {:.4f} at epoch {}".format(
Expand Down
24 changes: 12 additions & 12 deletions examples/classification_3d_ignite/densenet_evaluation_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,16 @@ def main():

# IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
images = [
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI607-Guys-1097-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI175-HH-1570-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI385-HH-2078-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI344-Guys-0905-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI409-Guys-0960-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI584-Guys-1129-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI253-HH-1694-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI092-HH-1436-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI574-IOP-1156-T1.nii.gz"]),
os.sep.join(["workspace", "data", "medical", "ixi", "IXI-T1", "IXI585-Guys-1130-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI607-Guys-1097-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI175-HH-1570-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI385-HH-2078-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI344-Guys-0905-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI409-Guys-0960-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI584-Guys-1129-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI253-HH-1694-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI092-HH-1436-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI574-IOP-1156-T1.nii.gz"]),
os.sep.join(["./workspace", "data", "medical", "ixi", "IXI-T1", "IXI585-Guys-1130-T1.nii.gz"]),
]

# 2 binary labels for gender classification: man and woman
Expand All @@ -52,7 +52,7 @@ def main():
val_ds = NiftiDataset(image_files=images, labels=labels, transform=val_transforms, image_only=False)
# create DenseNet121
net = monai.networks.nets.densenet.densenet121(spatial_dims=3, in_channels=1, out_channels=2)
device = torch.device("cuda:0")
device = torch.device("cpu")

metric_name = "Accuracy"
# add evaluation metric to the evaluator engine
Expand Down Expand Up @@ -81,7 +81,7 @@ def prepare_batch(batch, device=None, non_blocking=False):
prediction_saver.attach(evaluator)

# the model was trained by "densenet_training_array" example
CheckpointLoader(load_path="./runs/net_checkpoint_20.pth", load_dict={"net": net}).attach(evaluator)
CheckpointLoader(load_path="./runs_array/net_checkpoint_20.pth", load_dict={"net": net}).attach(evaluator)

# create a validation data loader
val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())
Expand Down
Loading

0 comments on commit 3103deb

Please sign in to comment.