<a href="https://colab.research.google.com/github/Usool-Data-Science/500-AI-Machine-learning-Deep-learning-Computer-vision-NLP-Projects-with-code/blob/main/Feature_Unet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

In [18]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import torchvision

import os
import numpy as np
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.applications.vgg16 import preprocess_input, VGG16
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Dropout, UpSampling2D, Concatenate
from google.colab import files

In [19]:
# Define the convolutional layer
conv_layer = nn.Sequential(
    #take in an input image with 3 RGB color channels and apply 16 filters to produce 16 output feature maps. 
    #Each filter is a 3x3 matrix of weights that is convolved with the input image to extract features
    nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
    nn.BatchNorm2d(32),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2)
)


In [20]:
# Define the self-attention layer
class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()

        self.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        batch_size, channels, height, width = x.size()
        
        # Project the inputs to query, key, and value
        proj_query = self.query_conv(x).view(batch_size, -1, width * height).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(batch_size, -1, width * height)

        # Compute the attention scores
        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)

        # Apply the attention to the value
        proj_value = self.value_conv(x).view(batch_size, -1, width * height)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(batch_size, channels, height, width)

        # Apply the scaling factor and add to the input
        out = self.gamma * out + x

        return out

In [21]:
# Define the image transformer
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [22]:
# Instantiate the self-attention layer
self_att_layer = SelfAttention(32)

# Define the directory path and loop through the images
# directory_path = '/content/drive/MyDrive/imageFusion'
# output_path = '/content/drive/MyDrive/imageFusion/FusionOutput'

images_to_fuse = ['/content/drive/MyDrive/imageFusion/Source3/lytro-01-A.jpg',
                  '/content/drive/MyDrive/imageFusion/Source3/lytro-01-B-New.jpg',
                  '/content/drive/MyDrive/imageFusion/Source3/lytro-01-C-New.jpg']


def extract_feature_maps(images_to_fuse):
  images_fused = []
  for filename in images_to_fuse:
      if filename.endswith('.jpg') or filename.endswith('.jpeg') or filename.endswith('.png'):
          # Load the image
          # image = Image.open(os.path.join(directory_path, filename))
          image = Image.open(filename)
          # Transform the image
          image = transform(image).unsqueeze(0)
          # Extract the features using the convolutional layer
          features = conv_layer(image)
          # Apply the self-attention layer to the features
          spatial_features = self_att_layer(features)
          # Save the spatial features to disk
          # output_filename = os.path.splitext(filename)[0] + '.pt'
          # torch.save(spatial_features, os.path.join(output_path, output_filename))
          images_fused.append(spatial_features)
  return(images_fused)


In [14]:
import torch

# Define the U-Net architecture
class UNet(torch.nn.Module):
    def __init__(self):
        super(UNet, self).__init__()

        # Encoder
        self.conv1 = torch.nn.Conv2d(32, 64, 3, stride=1, padding=1)
        self.conv2 = torch.nn.Conv2d(64, 128, 3, stride=1, padding=1)
        self.conv3 = torch.nn.Conv2d(128, 256, 3, stride=1, padding=1)
        self.conv4 = torch.nn.Conv2d(256, 512, 3, stride=1, padding=1)

        # Decoder
        self.upconv4 = torch.nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv5 = torch.nn.Conv2d(512, 256, 3, stride=1, padding=1)
        self.upconv3 = torch.nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv6 = torch.nn.Conv2d(256, 128, 3, stride=1, padding=1)
        self.upconv2 = torch.nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv7 = torch.nn.Conv2d(128, 64, 3, stride=1, padding=1)
        self.upconv1 = torch.nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.conv8 = torch.nn.Conv2d(64, 32, 3, stride=1, padding=1)

        self.final_conv = torch.nn.Conv2d(32, 3, 3, stride=1, padding=1)

    def forward(self, x):
        # Encoder
        x1 = torch.nn.functional.relu(self.conv1(x))
        x2 = torch.nn.functional.relu(self.conv2(x1))
        x3 = torch.nn.functional.relu(self.conv3(x2))
        x4 = torch.nn.functional.relu(self.conv4(x3))

        # Decoder
        y = torch.nn.functional.relu(self.upconv4(x4))
        y = torch.cat([y, x3], dim=1)
        y = torch.nn.functional.relu(self.conv5(y))

        y = torch.nn.functional.relu(self.upconv3(y))
        y = torch.cat([y, x2], dim=1)
        y = torch.nn.functional.relu(self.conv6(y))

        y = torch.nn.functional.relu(self.upconv2(y))
        y = torch.cat([y, x1], dim=1)
        y = torch.nn.functional.relu(self.conv7(y))

        y = torch.nn.functional.relu(self.upconv1(y))
        y = torch.cat([y, x], dim=1)
        y = torch.nn.functional.relu(self.conv8(y))

        y = torch.nn.functional.tanh(self.final_conv(y))

        return y


In [23]:
# Define the U-Net architecture
class UNet(torch.nn.Module):
    def __init__(self):
        super(UNet, self).__init__()

        # Encoder
        self.conv1 = torch.nn.Conv2d(32, 64, 3, stride=1, padding=1)
        self.conv2 = torch.nn.Conv2d(64, 128, 3, stride=1, padding=1)
        self.conv3 = torch.nn.Conv2d(128, 256, 3, stride=1, padding=1)
        self.conv4 = torch.nn.Conv2d(256, 512, 3, stride=1, padding=1)

        # Decoder
        self.upconv4 = torch.nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv5 = torch.nn.Conv2d(256, 256, 3, stride=1, padding=1)
        self.upconv3 = torch.nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv6 = torch.nn.Conv2d(128, 128, 3, stride=1, padding=1)
        self.upconv2 = torch.nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv7 = torch.nn.Conv2d(64, 64, 3, stride=1, padding=1)
        self.conv8 = torch.nn.Conv2d(64, 3, 3, stride=1, padding=1)

        self.relu = torch.nn.ReLU()

    def forward(self, x):
        # Encoder
        x1 = self.relu(self.conv1(x))
        x2 = self.relu(self.conv2(x1))
        x3 = self.relu(self.conv3(x2))
        x4 = self.relu(self.conv4(x3))

        # Decoder
        x = self.relu(self.upconv4(x4))
        x = torch.cat((x, x3), dim=1)
        x = self.relu(self.conv5(x))
        x = self.relu(self.upconv3(x))
        x = torch.cat((x, x2), dim=1)
        x = self.relu(self.conv6(x))
        x = self.relu(self.upconv2(x))
        x = torch.cat((x, x1), dim=1)
        x = self.relu(self.conv7(x))
        x = self.conv8(x)

        return x


In [24]:
# Load the feature maps
images_to_fuse = ['/content/drive/MyDrive/imageFusion/Source3/lytro-01-A.jpg',
                  '/content/drive/MyDrive/imageFusion/Source3/lytro-01-B-New.jpg',
                  '/content/drive/MyDrive/imageFusion/Source3/lytro-01-C-New.jpg']

feature_maps = extract_feature_maps(images_to_fuse)

# Concatenate the feature maps
x = torch.cat(feature_maps, dim=1)

# Instantiate the U-Net model
model = UNet()

# Generate the fused image
fused_image = model(x)

# Save the fused image
torchvision.utils.save_image(fused_image, 'fused_image.jpg')

RuntimeError: ignored