<a href="https://colab.research.google.com/github/Polisetty-Cyril/Astronomical-Image-Denoising-and-Enhancement/blob/main/Astronomical_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [11]:
import numpy as np
import matplotlib.pyplot as plt
from astropy.io import fits
from sklearn.feature_extraction.image import extract_patches_2d
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [3]:
hdul = fits.open('/content/hlsp_heritage_hst_acs-wfc_m51_f555w_v1_drz_sci.fits')
image_data = hdul[0].data
hdul.close()
image_data = image_data / np.max(image_data)

In [15]:
def crop_center(img, cropx=512, cropy=512):
  y, x = img.shape
  startx = x // 2 - cropx // 2
  starty = y // 2 - cropy // 2
  return img[starty:starty+cropy, startx:startx+cropx]


image_data = crop_center(image_data)

In [16]:
noise_sigma = 0.1
noisy_image = image_data + np.random.normal(0, noise_sigma, image_data.shape)
patches = extract_patches_2d(noisy_image, (64, 64), max_patches=1000)

In [17]:
def mask_patches(patches, mask_fraction=0.05):
  masked = np.copy(patches)
  masks = np.zeros_like(patches)


  for i in range(len(patches)):
    total = patches[i].size
    n_mask = int(total * mask_fraction)
    coords = np.unravel_index(np.random.choice(total, n_mask, replace=False), patches[i].shape)
    masked[i][coords] = 0
    masks[i][coords] = 1


  return masked, patches, masks


masked_patches, targets, masks = mask_patches(patches)

In [21]:
class DenoiseDataset(Dataset):
  def __init__(self, X, Y, M):
    self.X = torch.tensor(X[:, None, ...], dtype = torch.float32)
    self.Y = torch.tensor(Y[:, None, ...], dtype = torch.float32)
    self.M = torch.tensor(M[:, None, ...], dtype = torch.float32)

  def __len__(self):
    return len(self.X)

  def __getitem__(self, idx):
    return self.X[idx], self.Y[idx], self.M[idx]

dataset = DenoiseDataset(masked_patches, targets, masks)
dataloader = DataLoader(dataset, batch_size = 32, shuffle = True)

In [22]:
class UNet(nn.Module):
  def __init__(self):
    super().__init__()
    self.enc1 = nn.Sequential(nn.Conv2d(1, 32, 3, padding=1), nn.ReLU())
    self.pool1 = nn.MaxPool2D(2)
    self.enc2 = nn.Sequential(nn.Conv2d(32, 64, 3, padding=1), nn.ReLU())
    self.pool2 = nn.MaxPool2D(2)


    self.bottleneck = nn.Sequential(nn.Conv2d(64, 128, 3, padding=1), nn.ReLU())


    self.up1 = nn.Upsample(scale_factor=2, mode='nearest')
    self.dec1 = nn.Sequential(nn.Conv2d(128, 64, 3, padding=1), nn.ReLU())
    self.up2 = nn.Upsample(scale_factor=2, mode='nearest')
    self.dec2 = nn.Sequential(nn.Conv2d(64, 32, 3, padding=1), nn.ReLU())
    self.out = nn.Conv2d(32, 1, 1)


  def forward(self, x):
    x1 = self.enc1(x)
    x2 = self.enc2(self.pool1(x1))
    x3 = self.bottleneck(self.pool2(x2))
    x = self.up1(x3)
    x = self.dec1(x)
    x = self.up2(x)
    x = self.dec2(x)
    return self.out(x)

In [23]:
model = UNet()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


for epoch in range(5):
  model.train()
  total_loss = 0
  for x, y, m in dataloader:
    x, y, m = x.to(device), y.to(device), m.to(device)
    pred = model(x)
    loss = F.mse_loss(pred * m, y * m)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    total_loss += loss.item()


  print(f"Epoch {epoch+1}, Loss: {total_loss/len(dataloader):.4f}")

AttributeError: module 'torch.nn' has no attribute 'MaxPool2D'