In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import h5py
from torchvision import transforms
from tqdm import tqdm
import numpy as np

# Check for CUDA
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Photometric Feature Extractor (MLP)
class PhotometricMLP(nn.Module):
    def __init__(self, input_size):
        super(PhotometricMLP, self).__init__()
        self.fc1 = nn.Linear(input_size, 128)
        self.bn1 = nn.BatchNorm1d(128)
        self.fc2 = nn.Linear(128, 128)
        self.bn2 = nn.BatchNorm1d(128)
        self.dropout = nn.Dropout(0.3)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        x = self.relu(self.bn1(self.fc1(x)))
        x = self.dropout(self.relu(self.bn2(self.fc2(x))))
        return x

# Custom CNN for Image Feature Extraction (Improved with Adaptive Pooling and Extra Layer)
class ImageCNN(nn.Module):
    def __init__(self):
        super(ImageCNN, self).__init__()
        self.conv1 = nn.Conv2d(5, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.adaptive_pool = nn.AdaptiveAvgPool2d((4, 4))
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)

        self.fc = nn.Linear(256 * 4 * 4, 128)
    
    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.pool(self.relu(self.conv3(x)))
        x = self.pool(self.relu(self.conv4(x)))
        x = self.adaptive_pool(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(self.relu(self.fc(x)))
        return x

# Hybrid Model
class HybridRedshiftModel(nn.Module):
    def __init__(self, photometric_input):
        super(HybridRedshiftModel, self).__init__()
        self.photo = PhotometricMLP(photometric_input).to(device)
        self.image = ImageCNN().to(device)
        self.fc = nn.Linear(128 * 2, 1)  # Combining photometric and image features
    
    def forward(self, photo, img):
        photo_feat = self.photo(photo)
        img_feat = self.image(img)
        fused = torch.cat((photo_feat, img_feat), dim=1)
        output = self.fc(fused).squeeze(1)
        return output

# Custom Dataset for HDF5 File
class HDF5Dataset(data.Dataset):
    def __init__(self, hdf5_path):
        self.hdf5_path = hdf5_path
        with h5py.File(hdf5_path, 'r') as f:
            self.length = len(f['image'])
        self.transform = transforms.Compose([
            transforms.Resize((64, 64)),
            transforms.ToTensor(),
        ])
    
    def __len__(self):
        return self.length
    
    def __getitem__(self, idx):
        with h5py.File(self.hdf5_path, 'r') as f:
            img = torch.tensor(f['image'][idx], dtype=torch.float32)
            photo = torch.tensor([
                f['g_cmodel_mag'][idx], f['r_cmodel_mag'][idx], f['i_cmodel_mag'][idx],
                f['z_cmodel_mag'][idx], f['y_cmodel_mag'][idx],
                f['g_ellipticity'][idx], f['r_ellipticity'][idx], f['i_ellipticity'][idx],
                f['z_ellipticity'][idx], f['y_ellipticity'][idx],
                f['g_sersic_index'][idx], f['r_sersic_index'][idx], f['i_sersic_index'][idx],
                f['z_sersic_index'][idx], f['y_sersic_index'][idx],
            ], dtype=torch.float32)
            redshift = torch.tensor(f['specz_redshift'][idx], dtype=torch.float32)
        return photo, img, redshift

# Training function with Early Stopping and Learning Rate Scheduler
def train_model(model, dataloader, optimizer, criterion, num_epochs=10, patience=3, save_path='24_mar_model.pth'):
    model.train()
    best_loss = np.inf
    early_stop_counter = 0
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)
    
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        with tqdm(dataloader, unit="batch") as tepoch:
            for photo, img, labels in tepoch:
                tepoch.set_description(f"Epoch {epoch+1}")
                photo, img, labels = photo.to(device), img.to(device), labels.to(device)
                optimizer.zero_grad()
                outputs = model(photo, img)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                tepoch.set_postfix(loss=loss.item())
        
        avg_loss = epoch_loss / len(dataloader)
        print(f"Epoch {epoch+1} Loss: {avg_loss}")
        scheduler.step(avg_loss)
        
        if avg_loss < best_loss:
            best_loss = avg_loss
            early_stop_counter = 0
            torch.save(model.state_dict(), save_path)
            print(f"Model saved at epoch {epoch+1} with loss {best_loss}")
        else:
            early_stop_counter += 1
            if early_stop_counter >= patience:
                print("Early stopping triggered.")
                break

# Instantiate model and optimizer
model = HybridRedshiftModel(photometric_input=15).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.HuberLoss()

# Load Dataset and Dataloader
dataset = HDF5Dataset("D:/Galaxy Datasets/temp_training.hdf5")
dataloader = data.DataLoader(dataset, batch_size=32, shuffle=True)

# Train Model
train_model(model, dataloader, optimizer, criterion, num_epochs=10, patience=3, save_path='24_mar_model.pth')

# Load Best Model for Testing
model.load_state_dict(torch.load('24_mar_model.pth'))
model.eval()

# Evaluate Model
mse_loss = 0.0
absolute_error = 0.0
num_samples = len(dataset)
with torch.no_grad():
    for photo, img, label in dataloader:
        photo, img, label = photo.to(device), img.to(device), label.to(device)
        predictions = model(photo, img)
        mse_loss += nn.MSELoss()(predictions, label).item()
        absolute_error += torch.abs(predictions - label).sum().item()

mse_loss /= num_samples
mae_loss = absolute_error / num_samples
print(f"Mean Squared Error (MSE): {mse_loss}")
print(f"Mean Absolute Error (MAE): {mae_loss}")


Epoch 1: 100%|██████████| 6393/6393 [11:05<00:00,  9.60batch/s, loss=0.0578] 


Epoch 1 Loss: 0.04116680618957088
Model saved at epoch 1 with loss 0.04116680618957088


Epoch 2: 100%|██████████| 6393/6393 [12:15<00:00,  8.69batch/s, loss=0.00432]  


Epoch 2 Loss: 0.032233147479178714
Model saved at epoch 2 with loss 0.032233147479178714


Epoch 3: 100%|██████████| 6393/6393 [11:28<00:00,  9.29batch/s, loss=0.0118] 


Epoch 3 Loss: 0.02999175899105877
Model saved at epoch 3 with loss 0.02999175899105877


Epoch 4: 100%|██████████| 6393/6393 [16:07<00:00,  6.61batch/s, loss=0.0237]  


Epoch 4 Loss: 0.028998501972536735
Model saved at epoch 4 with loss 0.028998501972536735


Epoch 5: 100%|██████████| 6393/6393 [11:43<00:00,  9.09batch/s, loss=0.0078] 


Epoch 5 Loss: 0.027572316534243392
Model saved at epoch 5 with loss 0.027572316534243392


Epoch 6: 100%|██████████| 6393/6393 [10:35<00:00, 10.05batch/s, loss=0.00863]


Epoch 6 Loss: 0.02733594663277657
Model saved at epoch 6 with loss 0.02733594663277657


Epoch 7: 100%|██████████| 6393/6393 [10:03<00:00, 10.59batch/s, loss=0.0382] 


Epoch 7 Loss: 0.028114734391300755


Epoch 8: 100%|██████████| 6393/6393 [09:22<00:00, 11.36batch/s, loss=0.0145] 


Epoch 8 Loss: 0.027771796453366866


Epoch 9: 100%|██████████| 6393/6393 [09:24<00:00, 11.32batch/s, loss=0.0144] 


Epoch 9 Loss: 0.026899618002577765
Model saved at epoch 9 with loss 0.026899618002577765


Epoch 10: 100%|██████████| 6393/6393 [09:15<00:00, 11.51batch/s, loss=0.0315] 


Epoch 10 Loss: 0.026001005889127494
Model saved at epoch 10 with loss 0.026001005889127494


RuntimeError: Error(s) in loading state_dict for HybridRedshiftModel:
	Missing key(s) in state_dict: "photo.bn1.weight", "photo.bn1.bias", "photo.bn1.running_mean", "photo.bn1.running_var", "photo.bn2.weight", "photo.bn2.bias", "photo.bn2.running_mean", "photo.bn2.running_var", "image.conv4.weight", "image.conv4.bias". 
	size mismatch for image.fc.weight: copying a param with shape torch.Size([128, 8192]) from checkpoint, the shape in current model is torch.Size([128, 4096]).

In [4]:
model.load_state_dict(torch.load('24_mar_model.pth'))
model.eval()

# Evaluate Model
mse_loss = 0.0
absolute_error = 0.0
num_samples = len(dataset)
with torch.no_grad():
    for photo, img, label in dataloader:
        photo, img, label = photo.to(device), img.to(device), label.to(device)
        predictions = model(photo, img)
        mse_loss += nn.MSELoss()(predictions, label).item()
        absolute_error += torch.abs(predictions - label).sum().item()

mse_loss /= num_samples
mae_loss = absolute_error / num_samples
print(f"Mean Squared Error (MSE): {mse_loss}")
print(f"Mean Absolute Error (MAE): {mae_loss}")

Mean Squared Error (MSE): 0.007328542011662434
Mean Absolute Error (MAE): 0.10066949397290989


In [5]:
percentage_accuracy = 100 * (1 - mae_loss / torch.mean(torch.tensor(dataset[:][2])).item())
print(f"Percentage Accuracy: {percentage_accuracy}%")

  photo = torch.tensor([


Percentage Accuracy: 83.08095812483602%


  percentage_accuracy = 100 * (1 - mae_loss / torch.mean(torch.tensor(dataset[:][2])).item())


In [6]:
for i in range (0,5):
    print(f"Original value: {label[i]}      Predicted value: {predictions[i]}")

Original value: 0.16672000288963318      Predicted value: 0.21281550824642181
Original value: 0.6856560111045837      Predicted value: 0.5945178270339966
Original value: 0.1525000035762787      Predicted value: 0.49544042348861694
Original value: 2.0625600814819336      Predicted value: 2.1511037349700928
Original value: 0.8759999871253967      Predicted value: 1.2244274616241455
