diff --git a/agent.py b/agent.py index a2a5051..92a2990 100644 --- a/agent.py +++ b/agent.py @@ -21,9 +21,15 @@ def __init__(self, args, env): self.discount = args.discount self.online_net = DQN(args, self.action_space).to(device=args.device) - if args.model and os.path.isfile(args.model): - # Always load tensors onto CPU by default, will shift to GPU if necessary - self.online_net.load_state_dict(torch.load(args.model, map_location='cpu')) + # Load model if provided (raises exception for incorrect model path) + if args.model: + if os.path.isfile(args.model) + # Always load tensors onto CPU by default, will shift to GPU if necessary + self.online_net.load_state_dict(torch.load(args.model, map_location='cpu')) + print("Loading pretrained model: " + args.model) + else: + raise FileNotFoundError(args.model) + self.online_net.train() self.target_net = DQN(args, self.action_space).to(device=args.device)