
# **Galaxy Zoo 2 classification notebook**

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Deyht/AI_astro_ED_AAIF/blob/main/practical_works/CNN/classification/gz2_classification/gz2_classification.ipynb)

---


### **CIANNA installation**

#### Query GPU allocation and properties

If nvidia-smi fails, it might indicate that you launched the Colab session without a GPU reservation.  
To change the type of reservation go to "Runtime"->"Change runtime type" and select "GPU" as your hardware accelerator.

In [None]:
%%shell

nvidia-smi

cd /content/

git clone https://github.com/NVIDIA/cuda-samples/

cd /content/cuda-samples/Samples/1_Utilities/deviceQuery/

cmake CMakeLists.txt

make SMS="50 60 70 80"

./deviceQuery | grep Capability | cut -c50- > ~/cuda_infos.txt
./deviceQuery | grep "CUDA Driver Version / Runtime Version" | cut -c57- >> ~/cuda_infos.txt

cd ~/

If you are granted a GPU that does not support FP16 computation, it is advised to change the mixed precision method to FP32C_FP32A in the corresponding cells.  
See the detailed description of mixed precision support in CIANNA on the [System Requirements](https://github.com/Deyht/CIANNA/wiki/1\)-System-Requirements) wiki page.
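As an illustration, the precision mode could be chosen from the compute capability that the query cell stores in `~/cuda_infos.txt`. This is a hypothetical sketch: `parse_cuda_infos` and `choose_mixed_precision` are not CIANNA functions, and the 7.0 threshold (Volta, the first generation with tensor cores) is an assumption. The wiki page above remains the reference for which GPUs support FP16 in CIANNA.

```python
import os

def parse_cuda_infos(text):
    """Parse the two-line content of ~/cuda_infos.txt written by the GPU query cell:
    line 1 is the compute capability (e.g. "7.5"), line 2 the CUDA runtime version."""
    lines = [line.strip() for line in text.splitlines() if line.strip()]
    return float(lines[0]), lines[1]

def choose_mixed_precision(comp_cap, min_fp16_cap=7.0):
    """Hypothetical helper returning a mixed_precision string for cnn.init.
    ASSUMPTION: FP16 compute is only used from compute capability min_fp16_cap upward."""
    return "FP16C_FP32A" if comp_cap >= min_fp16_cap else "FP32C_FP32A"

if __name__ == "__main__":
    path = os.path.expanduser("~/cuda_infos.txt")
    if os.path.isfile(path):
        with open(path) as f:
            comp_cap, cuda_vers = parse_cuda_infos(f.read())
        print(comp_cap, cuda_vers, choose_mixed_precision(comp_cap))
```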

#### Clone CIANNA git repository

In [None]:
%%shell

cd /content/

git clone https://github.com/Deyht/CIANNA

cd CIANNA

#### Compiling CIANNA for the allocated GPU generation

There is no guaranteed forward or backward compatibility between Nvidia GPU generations, and some capabilities are generation-specific. For these reasons, CIANNA must be given the GPU generation of the platform at compile time.
The following cell will automatically update all the necessary files based on the detected GPU, and compile CIANNA.

In [None]:
%%shell

cd /content/CIANNA

mult="10"
cat ~/cuda_infos.txt
comp_cap="$(sed '1!d' ~/cuda_infos.txt)"
cuda_vers="$(sed '2!d' ~/cuda_infos.txt)"

lim="11.1"
old_arg=$(awk '{if ($1 < $2) print "-D CUDA_OLD";}' <<<"${cuda_vers} ${lim}")

sm_val=$(awk '{print $1*$2}' <<<"${mult} ${comp_cap}")

gen_val=$(awk '{if ($1 >= 80) print "-D GEN_AMPERE"; else if($1 >= 70) print "-D GEN_VOLTA";}' <<<"${sm_val}")

sed -i "s/.*arch=sm.*/\\t\tcuda_arg=\"\$cuda_arg -D CUDA -D comp_CUDA -lcublas -lcudart -arch=sm_$sm_val $old_arg $gen_val\"/g" compile.cp
sed -i "s/\/cuda-[0-9][0-9].[0-9]/\/cuda-$cuda_vers/g" compile.cp
sed -i "s/\/cuda-[0-9][0-9].[0-9]/\/cuda-$cuda_vers/g" src/python_module_setup.py

./compile.cp CUDA PY_INTERF

mv src/build/lib.linux-x86_64-* src/build/lib.linux-x86_64
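For readers less familiar with `awk`, the flag selection performed by the compile cell can be restated in Python. This is only a sketch mirroring the shell logic (`cianna_cuda_flags` is a hypothetical name); the shell cell is what actually edits `compile.cp`.

```python
def cianna_cuda_flags(comp_cap, cuda_vers, old_limit="11.1"):
    """Mirror of the awk logic in the compile cell (hypothetical helper).

    comp_cap  : compute capability string, e.g. "7.5" (line 1 of ~/cuda_infos.txt)
    cuda_vers : CUDA runtime version string, e.g. "12.2" (line 2)
    Returns the sm_XX architecture name and the extra -D flags.
    """
    sm_val = int(round(10 * float(comp_cap)))   # "7.5" -> 75, like mult*comp_cap in awk
    flags = []
    if float(cuda_vers) < float(old_limit):     # awk compares $1 < $2 numerically
        flags.append("-D CUDA_OLD")
    if sm_val >= 80:
        flags.append("-D GEN_AMPERE")
    elif sm_val >= 70:
        flags.append("-D GEN_VOLTA")
    return "sm_%d" % sm_val, flags

print(cianna_cuda_flags("7.5", "12.2"))
```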

#### CIANNA notebook guideline

**IMPORTANT NOTE**   
CIANNA is mainly used in a script fashion and was not designed to run in notebooks. Every code cell that directly invokes CIANNA functions must be run as a script to avoid possible errors.  
To do so, the cell must have the following structure.

```
%%shell

cd /content/CIANNA

python3 - <<EOF

[... your python code ...]

EOF
```

This syntax makes it easy to edit the Python code in the notebook while running the cell as a script. Note that, in this context, the cell cannot access any variable defined elsewhere in the notebook.
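A rough Python analogue of this pattern runs code in a separate interpreter fed through standard input, which shows why notebook variables are invisible to it. One consequence of the unquoted `EOF` delimiter is worth knowing: bash expands `$name` expressions inside the block before Python reads it, so the Python code should avoid literal `$` characters. `run_as_script` is a hypothetical helper, not part of CIANNA.

```python
import subprocess
import sys

def run_as_script(code, cwd=None):
    """Run `code` in a fresh Python interpreter fed through stdin,
    like `python3 - <<EOF ... EOF` in a %%shell cell (hypothetical helper)."""
    result = subprocess.run([sys.executable, "-"], input=code, text=True,
                            capture_output=True, cwd=cwd)
    return result.returncode, result.stdout, result.stderr

notebook_var = 42  # defined in the current interpreter only

rc, out, err = run_as_script("print(2 + 2)")
rc2, out2, err2 = run_as_script("print(notebook_var)")  # NameError in the child process
print(out.strip(), rc2, err2.strip().splitlines()[-1])
```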


### **Galaxy Zoo 2 classification**

In the original Galaxy Zoo project, volunteers classified images of Sloan Digital Sky Survey (SDSS) galaxies into one of six categories: elliptical, clockwise spiral, anticlockwise spiral, edge-on, star/don't know, or merger. GZ2 extends the original Galaxy Zoo classifications to a subsample of the brightest and largest galaxies in the SDSS Legacy release, measuring more detailed morphological features. These include galactic bars, the number of spiral arms and their pitch angle, bulges, edge-on galaxies, relative ellipticities, and many others.

There are 243,434 images in total, all resized to 424x424 pixels. Each image is framed so that the main object is centered and part of its environment is visible, which implies that the field of view (FoV) differs from one image to another.
For simplicity, we will use cropped versions of these images, zoomed in toward the main object and resized to either 64x64 or 128x128 pixels.

Details on the classification process can be found in [Hart et al. 2016](https://academic.oup.com/mnras/article/461/4/3663/2608720?login=true).
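The coarse classes used in this notebook ("E", "SB", "S") are defined as prefixes of the GZ2 class labels. A plain prefix test for "S" would also match barred "SB..." labels, so a safer grouping assigns each label to the longest matching prefix. The sketch below illustrates that idea; `group_gz2_labels` is a hypothetical name and the labels are made-up strings in the GZ2 prefix style.

```python
import numpy as np

def group_gz2_labels(labels, class_list):
    """Map each label to the index of the longest class name that prefixes it (-1 if none)."""
    labels = np.asarray(labels, dtype=str)
    class_id = np.full(labels.shape[0], -1, dtype=int)
    best_len = np.zeros(labels.shape[0], dtype=int)
    for i, name in enumerate(class_list):
        # only overwrite a previous match if this prefix is longer
        match = np.char.startswith(labels, name) & (len(name) > best_len)
        class_id[match] = i
        best_len[match] = len(name)
    return class_id

labels = ["Ei", "Er", "Sc2m", "SBb", "Ser", "A"]   # illustrative labels
print(group_gz2_labels(labels, ["E", "SB", "S"]))
```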


#### Downloading and visualizing the data

In [None]:
%%writefile /content/helper.py


import numpy as np
import glob
from tqdm import tqdm
import matplotlib.pyplot as plt
from scipy import signal, ndimage
from PIL import Image
from PIL import ImageFilter

import albumentations as A
import cv2
import os,sys
import gc


image_size = 128
im_depth = 3


if(not os.path.isdir("gz2_images_%dx%d"%(image_size,image_size))):
	print("Downloading dataset ...")
	if(image_size == 64):
		os.system("wget https://share.obspm.fr/s/HJAXeYkiBsK4P8F/download/gz2_images_64x64.tar.gz")
		os.system("tar -xzf gz2_images_64x64.tar.gz")
	elif(image_size == 128):
		os.system("wget https://share.obspm.fr/s/fGRRQX6zao43mbE/download/gz2_images_128x128.tar.gz")
		os.system("tar -xzf gz2_images_128x128.tar.gz")
	else:
		print("Invalid image size ...")
		exit()

catalog_path = "gz2_filename_mapping_with_class.csv"

if(not os.path.isfile(catalog_path)):
	os.system("wget https://share.obspm.fr/s/fKqtaAQTQAt7k7J/download/gz2_filename_mapping_with_class.csv")


gz2_catalog_header = np.genfromtxt(catalog_path, delimiter=",", max_rows=1, dtype="str")

gz2_catalog = np.genfromtxt(catalog_path, delimiter=",", skip_header=1, dtype="str")

np.random.shuffle(gz2_catalog)

do_not_exist = np.zeros(np.shape(gz2_catalog)[0], dtype="int")

for i in range(0, np.shape(gz2_catalog)[0]):
	path = "gz2_images_%dx%d/%s.jpg"%(image_size, image_size, gz2_catalog[i,1])
	if(not os.path.isfile(path)):
		do_not_exist[i] = 1

#Remove around 100 objects for which the image is not provided
index = np.where(do_not_exist == 1)[0]

gz2_catalog = np.delete(gz2_catalog, index, axis=0)

del(do_not_exist)
gc.collect()

total_nb_images = np.shape(gz2_catalog)[0]

# Full GZ2 class list (finer classification, not used here):
# class_list = ["A", "Ec", "Ei", "Er", "SBa", "SBb", "SBc", "SBd", "Sa", "Sb", "Sc", "Sd", "Se"]
class_list = ["E", "SB", "S"]

frac_test = 0.05

nb_class = int(len(class_list))

# Assign each galaxy to the class whose name is the longest prefix of its GZ2 label,
# so that barred "SB..." galaxies are not also counted in the "S" class (-1: no class)
gal_class = np.full(total_nb_images, -1, dtype="int")
best_len = np.zeros(total_nb_images, dtype="int")
for i in range(0, nb_class):
	match = np.char.startswith(gz2_catalog[:,2], class_list[i]) & (len(class_list[i]) > best_len)
	gal_class[match] = i
	best_len[match] = len(class_list[i])

class_count_train = np.zeros(nb_class, dtype="int")
class_count_test = np.zeros(nb_class, dtype="int")

for i in range(0, nb_class):
	index = np.where(gal_class == i)[0]
	class_count_test[i] = int(frac_test*np.shape(index)[0])
	class_count_train[i] = np.shape(index)[0] - class_count_test[i]

total_nb_train = np.sum(class_count_train)
total_nb_test = np.sum(class_count_test)

filename_to_class_array_train = np.zeros((total_nb_train,3), dtype="int")
filename_to_class_array_test = np.zeros((total_nb_test,3), dtype="int")

train_cumsum = np.zeros(nb_class+1, dtype="int")
test_cumsum = np.zeros(nb_class+1, dtype="int")

train_cumsum[1:] = np.cumsum(class_count_train)
test_cumsum[1:] = np.cumsum(class_count_test)

for i in range(0, nb_class):
	index = np.where(gal_class == i)[0]

	filename_to_class_array_train[train_cumsum[i]:train_cumsum[i+1], 0] = gz2_catalog[index[:class_count_train[i]],1]
	filename_to_class_array_train[train_cumsum[i]:train_cumsum[i+1], 1] = i
	filename_to_class_array_test[test_cumsum[i]:test_cumsum[i+1], 0] = gz2_catalog[index[class_count_train[i]:],1]
	filename_to_class_array_test[test_cumsum[i]:test_cumsum[i+1], 1] = i

transform = A.Compose([
	A.HorizontalFlip(p=0.5),
	A.VerticalFlip(p=0.5),
])

training_rebalance = np.clip(class_count_train[:],0,4000).astype("float")
print(training_rebalance)
training_rebalance /= np.sum(training_rebalance)

nb_im_train = 2048
nb_im_test = total_nb_test

input_data = np.zeros((nb_im_train,image_size*image_size*(im_depth+1)), dtype="float32")
targets = np.zeros((nb_im_train,nb_class), dtype="float32")

input_test = np.zeros((nb_im_test,image_size*image_size*(im_depth+1)), dtype="float32")
targets_test = np.zeros((nb_im_test,nb_class), dtype="float32")

zero_target = np.zeros(nb_class)

def create_train_batch(visual=0):

	if(visual):
		fig, axs = plt.subplots(4,5, figsize=(5,4), dpi=250, constrained_layout=True)

	for i in range(0,nb_im_train):
		r_class = np.random.choice(np.arange(0,nb_class), p=training_rebalance)

		i_d = train_cumsum[r_class] + int(np.random.random()*class_count_train[r_class])
		path = "gz2_images_%dx%d/%d.jpg"%(image_size,image_size,filename_to_class_array_train[i_d,0])

		patch = np.asarray(Image.open(path))
		transformed = transform(image=patch)
		patch = (transformed['image']/255.0)

		for depth in range(0,im_depth):
			input_data[i,depth*image_size*image_size:(depth+1)*image_size*image_size] = (np.copy(patch[:,:,depth]).flatten("C"))
		input_data[i,im_depth*image_size*image_size:] = 0.0 # spare last channel left empty
		targets[i,:] = 0
		targets[i,filename_to_class_array_train[i_d,1]] = 1

		if(visual and i < 20):
			axs[i//5][i%5].imshow(patch)
			axs[i//5][i%5].set_axis_off()
			axs[i//5][i%5].text(0.1,0.1, "%s"%(class_list[filename_to_class_array_train[i_d,1]]), c="limegreen", va="top", fontsize=6)

	if(visual):
		plt.savefig("training_set_example.jpg", dpi=250)
		return

	return input_data, targets


def create_test_batch():

	for i in range(0,nb_im_test):
		path = "gz2_images_%dx%d/%d.jpg"%(image_size, image_size, filename_to_class_array_test[i,0])

		patch = np.asarray(Image.open(path))/255.0

		for depth in range(0,im_depth):
			input_test[i,depth*image_size*image_size:(depth+1)*image_size*image_size] = (np.copy(patch[:,:,depth]).flatten("C"))
		input_test[i,im_depth*image_size*image_size:] = 0.0 # spare last channel left empty
		targets_test[i,:] = 0
		targets_test[i,filename_to_class_array_test[i,1]] = 1

	return input_test, targets_test
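
The vectors built by `create_train_batch` and `create_test_batch` are channel-planar: all pixels of channel 0 in row-major order, then channel 1, and so on, followed by an all-zero plane because the network is declared with `im_depth+1` input channels. The standalone sketch below reproduces that layout; `pack_planar` and `unpack_planar` are illustrative names, not helper.py functions.

```python
import numpy as np

def pack_planar(patch, n_ch_out):
    """Flatten an (H, W, C) image into a channel-planar vector with n_ch_out planes.
    Planes beyond C stay zero, like the spare channel in helper.py."""
    h, w, c = patch.shape
    vec = np.zeros(n_ch_out * h * w, dtype="float32")
    for d in range(c):
        vec[d*h*w:(d+1)*h*w] = patch[:, :, d].flatten("C")
    return vec

def unpack_planar(vec, h, w, c):
    """Inverse of pack_planar for the first c planes, back to (H, W, C)."""
    return vec[:c*h*w].reshape(c, h, w).transpose(1, 2, 0)

patch = np.random.default_rng(0).random((8, 6, 3)).astype("float32")
vec = pack_planar(patch, 4)
# Vectorized equivalent of the loop for the filled planes: channels first, then row-major
same = np.array_equal(vec[:3*8*6], np.transpose(patch, (2, 0, 1)).ravel())
print(vec.shape, same)
```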


In [None]:
%%shell

cd /content/

python3 - <<EOF

#Will download the dataset at the first call
from helper import *

create_train_batch(visual=1)

EOF
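Each image of a training batch, including the examples plotted by the cell above, is drawn by first picking a class with probability proportional to its training count clipped at 4000 (`training_rebalance`), then picking an image uniformly within that class. A minimal sketch of that sampling with made-up class sizes shows the effect: classes above the cap become equally likely, while smaller classes keep a proportionally lower weight. `rebalance_probs` is a hypothetical name.

```python
import numpy as np

def rebalance_probs(class_counts, cap=4000):
    """Class sampling probabilities proportional to min(count, cap), as training_rebalance."""
    w = np.clip(np.asarray(class_counts), 0, cap).astype("float")
    return w / np.sum(w)

counts = [20000, 3000, 45000]            # made-up class sizes, for illustration only
p = rebalance_probs(counts)              # [4000, 3000, 4000] / 11000
rng = np.random.default_rng(0)
draws = rng.choice(len(counts), size=100000, p=p)
freq = np.bincount(draws, minlength=len(counts)) / draws.size
print(p, freq)
```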

In [None]:
%cd /content/

from PIL import Image
import matplotlib.pyplot as plt

im = Image.open("training_set_example.jpg")
plt.figure(figsize=(5,4), dpi=200)
plt.imshow(im)
plt.gca().axis('off')
plt.show()


#### Train the classifier
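Before training, it can help to check how the spatial size evolves through the network defined in the next cell. The sketch below traces it under two assumptions: stride-1 convolutions (output size n + 2p - f + 1) and 2x2 max-pooling that floors odd sizes. Normalization layers are skipped since they do not change the size. Note that the second convolution uses a padding of 2 with a 3x3 filter, which enlarges the map by two pixels. `network_layers` and `trace_sizes` are hypothetical helpers, not CIANNA functions.

```python
def network_layers(image_size):
    """Layer list of the classifier in the training cell: ("conv", f, p) or ("pool", k).
    Normalization layers are omitted since they keep the spatial size."""
    layers = [("conv", 5, 2), ("pool", 2),
              ("conv", 3, 2), ("pool", 2),
              ("conv", 3, 1), ("pool", 2),
              ("conv", 3, 1), ("conv", 1, 0), ("conv", 3, 1), ("pool", 2)]
    if image_size == 128:
        layers += [("conv", 3, 1), ("conv", 1, 0), ("conv", 3, 1), ("pool", 2)]
    layers.append(("conv", 1, 0))  # 1x1 conv producing one map per class
    return layers

def trace_sizes(image_size):
    """Spatial size after each layer, before the final global average pooling."""
    n = image_size
    sizes = []
    for layer in network_layers(image_size):
        if layer[0] == "conv":
            _, f, p = layer
            n = n + 2 * p - f + 1   # stride-1 convolution
        else:
            n = n // layer[1]       # ASSUMPTION: pooling floors odd sizes
        sizes.append(n)
    return sizes

for size in (64, 128):
    print(size, trace_sizes(size))
```

Under these assumptions, the extra block for 128x128 inputs keeps the last feature map before the global average pooling at the same 4x4 size as for 64x64 inputs.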

In [None]:
%%shell

cd /content/

python3 - <<EOF


import time
import locale
import matplotlib.pyplot as plt
from scipy import signal
from threading import Thread

from helper import *
import numpy as np

#Comment out these lines to use a system-wide install
import sys, glob
sys.path.insert(0,glob.glob('/content/CIANNA/src/build/lib.*/')[-1])
import CIANNA as cnn


def i_ar(int_list):
	return np.array(int_list, dtype="int")

def f_ar(float_list):
	return np.array(float_list, dtype="float32")

def data_augm():

	data_augm, targets_augm = create_train_batch()
	cnn.delete_dataset("TRAIN_buf", silent=1)
	cnn.create_dataset("TRAIN_buf", nb_im_train, data_augm, targets_augm, silent=1)
	return


data_train, target_train = create_train_batch()
# The held-out test split (frac_test) is used as the validation set
data_valid, target_valid = create_test_batch()

cnn.init(in_dim=i_ar([image_size,image_size]), in_nb_ch=im_depth+1, out_dim=nb_class, \
		bias=0.1, b_size=8, comp_meth="C_CUDA", dynamic_load=1, mixed_precision="FP16C_FP32A")

cnn.create_dataset("TRAIN", size=nb_im_train, input=data_train, target=target_train)
cnn.create_dataset("VALID", size=nb_im_test , input=data_valid, target=target_valid)

del(data_valid)

load_epoch = 0
if (len(sys.argv) > 1):
	load_epoch = int(sys.argv[1])
if(load_epoch > 0):
	cnn.load("net_save/net0_s%04d.dat"%load_epoch,load_epoch,0)
else:

	cnn.conv(f_size=i_ar([5,5]), nb_filters=16  , padding=i_ar([2,2]), activation="RELU")
	cnn.pool(p_size=i_ar([2,2]), p_type="MAX")
	cnn.norm(group_size=4, activation="LIN")

	cnn.conv(f_size=i_ar([3,3]), nb_filters=32  , padding=i_ar([2,2]), activation="RELU")
	cnn.pool(p_size=i_ar([2,2]), p_type="MAX")
	cnn.norm(group_size=4, activation="LIN")

	cnn.conv(f_size=i_ar([3,3]), nb_filters=64  , padding=i_ar([1,1]), activation="RELU")
	cnn.pool(p_size=i_ar([2,2]), p_type="MAX")
	cnn.norm(group_size=8, activation="LIN")

	cnn.conv(f_size=i_ar([3,3]), nb_filters=128 , padding=i_ar([1,1]), activation="RELU")
	cnn.conv(f_size=i_ar([1,1]), nb_filters=64  , padding=i_ar([0,0]), activation="RELU")
	cnn.conv(f_size=i_ar([3,3]), nb_filters=128 , padding=i_ar([1,1]), activation="RELU")
	cnn.pool(p_size=i_ar([2,2]), p_type="MAX")
	cnn.norm(group_size=16, activation="LIN")

	if(image_size == 128):
		cnn.conv(f_size=i_ar([3,3]), nb_filters=128 , padding=i_ar([1,1]), activation="RELU")
		cnn.conv(f_size=i_ar([1,1]), nb_filters=64  , padding=i_ar([0,0]), activation="RELU")
		cnn.conv(f_size=i_ar([3,3]), nb_filters=128 , padding=i_ar([1,1]), activation="RELU")
		cnn.pool(p_size=i_ar([2,2]), p_type="MAX")
		cnn.norm(group_size=16, activation="LIN")

	cnn.conv(f_size=i_ar([1,1]), nb_filters=nb_class , padding=i_ar([0,0]), activation="LIN")
	cnn.pool(p_size=i_ar([1,1]), p_type="AVG", p_global=1, activation="SMAX")


for i in range(load_epoch,1200):
	t = Thread(target=data_augm)
	t.start()

	cnn.train(nb_iter=1, learning_rate=0.002, end_learning_rate=0.0001, lr_decay=0.0012, momentum=0.9, weight_decay=0.001,
		confmat=1, control_interv=10, save_every=50, TC_scale_factor=1.0)

	t.join()
	cnn.swap_data_buffers("TRAIN")


EOF
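The training loop overlaps data preparation with training: a thread builds the next augmented batch (`data_augm`) while `cnn.train` runs on the current one, and `cnn.swap_data_buffers` then switches to the new batch. The same double-buffering pattern, stripped of CIANNA, can be sketched with generic callables; `train_with_prefetch`, `make_batch` and `train_step` are hypothetical names.

```python
import itertools
import threading

def train_with_prefetch(n_epochs, make_batch, train_step):
    """Double buffering: build the next batch in a thread while training on the current one."""
    current = make_batch()
    for _ in range(n_epochs):
        holder = {}

        def worker_fn():
            holder["batch"] = make_batch()

        worker = threading.Thread(target=worker_fn)
        worker.start()
        train_step(current)        # train on the current buffer while the next one is built
        worker.join()              # wait until the next buffer is ready
        current = holder["batch"]  # swap buffers
    return current

counter = itertools.count()
trained = []
last = train_with_prefetch(3, lambda: next(counter), trained.append)
print(trained, last)
```

One extra batch is prepared after the last epoch and left unused, which is the price of never making training wait on data generation.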