Skip to content

Commit

Permalink
Adding loss to catalog
Browse files Browse the repository at this point in the history
  • Loading branch information
kai-polsterer committed Aug 16, 2023
1 parent e558542 commit 26ac230
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
18 changes: 10 additions & 8 deletions hipster.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,9 @@ def check_folders(self, base_folder):
Args:
base_folder (String): The base folder to check.
"""
if os.path.exists(os.path.join(self.output_folder, self.title, base_folder)):
answer = input("path exists, delete? Yes,[No]")
path = os.path.join(self.output_folder, self.title, base_folder)
if os.path.exists(path):
answer = input("path "+str(path)+", delete? Yes,[No]")
if answer == "Yes":
rmtree(os.path.join(self.output_folder, self.title, base_folder))
else:
Expand Down Expand Up @@ -170,7 +171,7 @@ def generate_hips(self, model):

print("creating tiles:")
for i in range(self.max_order+1):
print ("\n order "+str(i)+" ["+
print (" order "+str(i)+" ["+
str(12*4**i).rjust(int(math.log10(12*4**self.max_order))+1," ")+" tiles]:",
end="")
for j in range(12*4**i):
Expand Down Expand Up @@ -220,7 +221,7 @@ def generate_catalog(self, model, dataloader, catalog_file):
if answer != "Yes":
return
print("projecting dataset:")
coordinates, rotations = model.project_dataset(dataloader, 36)
coordinates, rotations, losses = model.project_dataset(dataloader, 36)
coordinates = coordinates.cpu().detach().numpy()
rotations = rotations.cpu().detach().numpy()
angles = numpy.array(healpy.vec2ang(coordinates))*180.0/math.pi
Expand All @@ -230,14 +231,15 @@ def generate_catalog(self, model, dataloader, catalog_file):
with open(os.path.join(self.output_folder,
self.title,
"catalog.csv"), 'w', encoding="utf-8") as output:
output.write("#id,RA2000,DEC2000,rotation,x,y,z,pix3,pix4,filename\n")
output.write("#id,RA2000,DEC2000,rotation,x,y,z,loss,filename\n")
for i in range(coordinates.shape[0]):
output.write(str(i)+","+str(angles[i,1])+"," +
str(90.0-angles[i,0])+"," +
str(rotations[i])+",")
output.write(str(coordinates[i,0])+"," +
str(coordinates[i,1])+"," +
str(coordinates[i,2])+",")
output.write(str(losses[i])+",")
output.write("http://localhost:8083" +
dataloader.dataset[i]['filename']+"\n")
output.flush()
Expand Down Expand Up @@ -314,7 +316,7 @@ def generate_dataset_projection(self, dataset, catalog_file):
print("done!")

if __name__ == "__main__":
myHipster = Hipster("HiPSter", "GZ", max_order=5, crop_size=256, output_size=64)
myHipster = Hipster("/hits/basement/ain/Data/HiPSter", "GZ", max_order=7, crop_size=256, output_size=128)
myModel = RotationalSphericalProjectingAutoencoder()
#checkpoint = torch.load("efigi_epoch41835-step753048.ckpt")
checkpoint = torch.load("gz_epoch4523-step1090284.ckpt")
Expand All @@ -332,8 +334,8 @@ def generate_dataset_projection(self, dataset, catalog_file):

myDataloader = DataLoader(myDataset, batch_size=1024, shuffle=False, num_workers=16)

myHipster.generate_catalog(myModel, myDataloader, "catalog.csv")
#myHipster.generate_catalog(myModel, myDataloader, "catalog.csv")

myHipster.generate_dataset_projection(myDataset, "catalog.csv")
#myHipster.generate_dataset_projection(myDataset, "catalog.csv")

#TODO: currently you manually have to call 'python3 -m http.server 8082' to start a simple web server providing access to the tiles.
4 changes: 3 additions & 1 deletion models/RotationalSphericalProjectingAutoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def training_step(self, train_batch, batch_idx):
def project_dataset(self, dataloader, rotation_steps):
result_coordinates = torch.zeros((0, 3))
result_rotations = torch.zeros((0))
result_losses = torch.zeros((0))
for batch in dataloader:
print(".", end="")
losses = torch.zeros((batch['id'].shape[0],rotation_steps))
Expand All @@ -98,8 +99,9 @@ def project_dataset(self, dataloader, rotation_steps):
min = torch.argmin(losses, dim=1)
result_coordinates = torch.cat((result_coordinates, coords[torch.arange(batch['id'].shape[0]),min]))
result_rotations = torch.cat((result_rotations, 360.0/rotation_steps*min))
result_losses = torch.cat((result_losses, losses[torch.arange(batch['id'].shape[0]),min]))
del losses
del coords
del min
gc.collect()
return result_coordinates, result_rotations
return result_coordinates, result_rotations, result_losses

0 comments on commit 26ac230

Please sign in to comment.