# Create the model


In [None]:
import torch
import torch.nn as nn
from torchvision import models
from torchvision import transforms

import torch
from torchsummary import summary
from torchviz import make_dot


class ReconstructionNet(nn.Module):
    """Reconstruct an input image from image features fused with a matrix.

    Architecture:
      * image branch: torchvision ResNet-18 (pretrained weights) with its final
        fully-connected layer replaced by an identity -> 512-d vector;
      * matrix branch: 3-D conv (1 -> 64 channels) + ReLU over the
        reconstruction matrix, globally average-pooled -> 64-d vector;
      * fusion: concatenation -> linear projection onto an 8 x 14 x 14 map ->
        grouped 3x3 conv to 256 channels -> ReLU -> dropout;
      * decoder: four stride-2 transposed convs, 14 -> 28 -> 56 -> 112 -> 224,
        ending with 3 output channels.

    The model owns its MSE criterion and Adam optimizer so ``train_step`` and
    ``validate_step`` can be called directly on batches.

    Args:
        reconstruction_matrix_size: Size of the matrix's first dimension. The
            matrix branch is pooled to a fixed width, so this is not needed to
            build any layer; it is stored for callers that inspect it.
        learning_rate: Learning rate for the Adam optimizer.
    """

    def __init__(self, reconstruction_matrix_size, learning_rate=0.001):
        super(ReconstructionNet, self).__init__()
        self.reconstruction_matrix_size = reconstruction_matrix_size

        # Image encoder with the classification head removed, leaving the
        # 512-d pooled ResNet-18 features.
        # NOTE(review): newer torchvision releases deprecate `pretrained=True`
        # in favour of `weights=...`; kept for compatibility with older versions.
        self.encoder = models.resnet18(pretrained=True)
        self.encoder.fc = nn.Identity()
        encoder_in_features = 512

        # Matrix branch: 3-D conv over a (batch, 1, depth, height, width) tensor.
        matrix_channels = 64
        self.matrix_conv = nn.Conv3d(in_channels=1, out_channels=matrix_channels, kernel_size=(3, 3, 3), padding=1)
        self.matrix_act = nn.ReLU(inplace=True)
        # Global pooling yields a fixed-width vector whatever the matrix shape,
        # which lets the fusion layer be sized here in __init__.
        self.matrix_pool = nn.AdaptiveAvgPool3d(1)

        # Feature fusion. Conv2d needs a spatial input, so the concatenated
        # feature vector is first projected onto an 8 x 14 x 14 map; 14 is
        # chosen so four 2x upsampling stages end exactly at 224 x 224.
        self.fusion_channels = 8
        self.fusion_size = 14
        self.fusion_fc = nn.Linear(
            encoder_in_features + matrix_channels,
            self.fusion_channels * self.fusion_size * self.fusion_size,
        )
        self.fusion_conv = nn.Conv2d(in_channels=self.fusion_channels, out_channels=256, kernel_size=3, groups=8, padding=1)
        self.fusion_act = nn.ReLU(inplace=True)

        # Dropout for regularization; it has no parameters, so reusing the one
        # module inside every decoder block is safe.
        self.dropout = nn.Dropout(p=0.2)

        # Decoder: with kernel 3, stride 2, padding 1 and output_padding 1 each
        # transposed conv exactly doubles height and width.
        self.decoder_block1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1),
            self.dropout,
            nn.ReLU(inplace=True)
        )
        self.decoder_block2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, stride=2, padding=1, output_padding=1),
            self.dropout,
            nn.ReLU(inplace=True)
        )
        self.decoder_block3 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=3, stride=2, padding=1, output_padding=1),
            self.dropout,
            nn.ReLU(inplace=True)
        )
        # output_padding=1 makes this stage 112 -> 224 (without it: 223),
        # matching the target image size used by the MSE loss.
        self.decoder_out = nn.ConvTranspose2d(in_channels=32, out_channels=3, kernel_size=3, stride=2, padding=1, output_padding=1)

        # Loss function (MSE between reconstruction and input image).
        self.criterion = nn.MSELoss()

        # Created after every layer so the optimizer tracks all parameters.
        self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)

        # Early-stopping state, updated once per epoch via update_early_stopping().
        self.patience = 5  # Number of epochs to wait for improvement
        self.min_loss = float('inf')
        self.counter = 0

    def forward(self, x, reconstruction_matrix):
        """Reconstruct ``x`` conditioned on ``reconstruction_matrix``.

        Args:
            x: Image batch ``(batch, 3, H, W)``. The decoder always emits
                224 x 224, so ``train_step``'s loss assumes 224 x 224 inputs.
            reconstruction_matrix: Batched matrix shaped ``(batch, H, W)``,
                ``(batch, D, H, W)`` or already ``(batch, 1, D, H, W)``.

        Returns:
            Reconstructed image batch ``(batch, 3, 224, 224)``.

        Raises:
            ValueError: If the matrix does not have 3, 4 or 5 dimensions.
        """
        batch_size = x.size(0)

        matrix = reconstruction_matrix
        if matrix.dim() not in (3, 4, 5):
            raise ValueError(
                f"reconstruction_matrix must have 3, 4 or 5 dims, got shape {tuple(matrix.shape)}"
            )
        # The channel axis must follow the batch axis; inserting it in front
        # of the batch would make Conv3d treat the whole batch as one volume
        # and mix samples together. 2-D matrices also get a depth axis of 1.
        if matrix.dim() == 3:
            matrix = matrix.unsqueeze(1)
        if matrix.dim() == 4:
            matrix = matrix.unsqueeze(1)

        # Matrix branch -> (batch, 64).
        matrix_features = self.matrix_act(self.matrix_conv(matrix))
        matrix_features = self.matrix_pool(matrix_features).flatten(1)

        # Image branch -> (batch, 512).
        image_features = self.encoder(x)

        # Fuse, then reshape into a spatial map for the convolutional path.
        fused = torch.cat((image_features, matrix_features), dim=1)
        fused = self.fusion_fc(fused).view(batch_size, self.fusion_channels, self.fusion_size, self.fusion_size)
        fused = self.fusion_act(self.fusion_conv(fused))
        fused = self.dropout(fused)

        decoded = self.decoder_block1(fused)
        decoded = self.decoder_block2(decoded)
        decoded = self.decoder_block3(decoded)
        return self.decoder_out(decoded)

    def train_step(self, x, reconstruction_matrix):
        """Run one optimization step on a batch and return its loss as a float."""
        reconstructed_image = self.forward(x, reconstruction_matrix)

        # The target is the input image itself.
        loss = self.criterion(reconstructed_image, x)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def validate_step(self, x, reconstruction_matrix):
        """Return the batch's reconstruction loss as a float, without updating weights.

        This no longer touches early-stopping state; call
        ``update_early_stopping`` once per epoch with the averaged loss.
        """
        with torch.no_grad():
            reconstructed_image = self.forward(x, reconstruction_matrix)
            loss = self.criterion(reconstructed_image, x)
        return loss.item()

    def update_early_stopping(self, val_loss):
        """Record an epoch-level validation loss and report whether to stop.

        Args:
            val_loss: Average validation loss for the epoch.

        Returns:
            True once ``patience`` consecutive epochs show no improvement
            over the best loss seen so far.
        """
        if val_loss < self.min_loss:
            self.min_loss = float(val_loss)
            self.counter = 0  # Improvement: reset the no-improvement counter
        else:
            self.counter += 1
        return self.counter >= self.patience



In [None]:
import os
import scipy.io as sio  
from PIL import Image  
import torch
import torch.nn as nn
from torchvision import models


class ReconstructionDataset(torch.utils.data.Dataset):
    """Dataset pairing every PNG image with its folder's ``.mat`` matrix.

    Expected layout (from the path construction in ``__init__``)::

        data_dir/
            <folder>/
                *.png
                <folder>.mat      # must contain a variable named 'rad'

    Each item is ``(image, matrix)``: the transformed image and a 2-D float32
    tensor built from the ``'rad'`` array.

    Args:
        data_dir: Root directory with one sub-folder per group of images.
        transform: Callable applied to the RGB PIL image. When falsy, a default
            resize-to-224 / ToTensor / ImageNet-normalize pipeline is used.
    """

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.image_paths = []
        self.matrix_paths = []
        self._default_transform = None  # built lazily on first use

        # os.listdir order is arbitrary; sort so dataset indices are reproducible.
        for subdir in sorted(os.listdir(data_dir)):
            subdir_path = os.path.join(data_dir, subdir)
            if not os.path.isdir(subdir_path):
                continue  # ignore stray files at the top level
            # All images in a folder share the matrix named after that folder.
            matrix_path = os.path.join(subdir_path, subdir + ".mat")
            for file_name in sorted(os.listdir(subdir_path)):
                if file_name.endswith(".png"):
                    self.image_paths.append(os.path.join(subdir_path, file_name))
                    self.matrix_paths.append(matrix_path)

    def __len__(self):
        return len(self.image_paths)

    def _get_transform(self):
        """Return the user-supplied transform, or the cached default pipeline."""
        if self.transform:
            return self.transform
        if self._default_transform is None:
            self._default_transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
        return self._default_transform

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        matrix_path = self.matrix_paths[idx]

        # Release the file handle promptly, and force 3 channels: PNGs may be
        # grayscale, palette or RGBA, which would break the 3-channel
        # Normalize and the ResNet encoder.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        image = self._get_transform()(image)

        matrix = sio.loadmat(matrix_path)['rad']
        # NOTE(review): for a 2-D 'rad' array this is a same-shape copy; for
        # higher-rank arrays it regroups the row-major elements into rows of
        # length shape[1] - confirm that is intended.
        matrix = matrix.flatten().reshape(-1, matrix.shape[1])
        matrix = torch.from_numpy(matrix).float()  # Convert to PyTorch tensor

        return image, matrix


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
import os
import PIL
import numpy as np


# Hyperparameters. learning_rate is passed to the model's Adam optimizer below.
learning_rate = 0.001
batch_size = 8
num_epochs = 10

data_root = r"C:\Users\aggar\Downloads\datasets_reorganised_new"
train_data_dir = os.path.join(data_root, "training")
val_data_dir = os.path.join(data_root, "validation")
test_data_dir = os.path.join(data_root, "testing")

# Define data transformations (ImageNet statistics, matching the ResNet encoder)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Create datasets
train_dataset = ReconstructionDataset(train_data_dir, transform=transform)
print(f"Train Dataset Length: {len(train_dataset)}")
val_dataset = ReconstructionDataset(val_data_dir, transform=transform)
test_dataset = ReconstructionDataset(test_data_dir, transform=transform)

# The model is sized from the first training sample, so fail early and
# clearly if the directory scan found nothing.
if len(train_dataset) == 0:
    raise RuntimeError(f"No .png images found under {train_data_dir!r}")

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Create model instance
model = ReconstructionNet(
    reconstruction_matrix_size=train_dataset[0][1].shape[0],
    learning_rate=learning_rate,
)


# Train the model

In [None]:

# Per-epoch average losses, used by the plotting cell.
train_losses = []
val_losses = []

# Early stopping on the epoch-level validation loss, using the model's
# configured patience (5 if the attribute is absent).
patience = getattr(model, "patience", 5)
best_val_loss = float("inf")
epochs_without_improvement = 0

# Training loop
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")

    # Train step: train_step does forward/backward/optimizer.step and
    # already returns the batch loss as a Python float.
    model.train()
    train_loss = 0.0
    for images, reconstruction_matrices in train_loader:
        train_loss += model.train_step(images, reconstruction_matrices)

    # max(..., 1) guards against an empty loader.
    avg_train_loss = train_loss / max(len(train_loader), 1)
    train_losses.append(avg_train_loss)

    # Validation step (validate_step also returns a Python float)
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for images, reconstruction_matrices in val_loader:
            val_loss += model.validate_step(images, reconstruction_matrices)

    avg_val_loss = val_loss / max(len(val_loader), 1)
    val_losses.append(avg_val_loss)

    print(f"Train Loss: {avg_train_loss:.4f}")
    print(f"Val Loss: {avg_val_loss:.4f}")

    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            print(f"Early stopping: no validation improvement for {patience} epochs.")
            break

# After training loop ends, evaluate on test set
model.eval()
test_loss = 0.0
with torch.no_grad():
    for images, reconstruction_matrices in test_loader:
        test_loss += model.validate_step(images, reconstruction_matrices)

avg_test_loss = test_loss / max(len(test_loader), 1)
print(f"Test Loss: {avg_test_loss:.4f}")


## Calculate RMSE Score, Plot Loss Curves

In [None]:

import torch
def calculate_rmse(predictions, targets):
    """Return the root-mean-square error between two tensors as a Python float.

    The mean is taken over *all* elements, so inputs of any (broadcast-
    compatible) shape, such as whole image batches, yield a single scalar.

    Args:
        predictions: Predicted values.
        targets: Ground-truth values, broadcastable against ``predictions``.
    """
    diff = predictions - targets
    if not diff.is_floating_point():
        diff = diff.float()  # mean() is not defined for integer tensors
    return torch.sqrt(torch.mean(diff ** 2)).item()

# Test-set RMSE over every pixel of every reconstructed image, measured in
# the normalized space the images are fed to the model in. Squared errors are
# accumulated per batch so all predictions never sit in memory at once.
model.eval()
squared_error_sum = 0.0
element_count = 0
with torch.no_grad():
    for images, reconstruction_matrices in test_loader:
        reconstructed = model(images, reconstruction_matrices)
        squared_error_sum += torch.sum((reconstructed - images) ** 2).item()
        element_count += images.numel()

rmse = (squared_error_sum / element_count) ** 0.5 if element_count else float("nan")

print(f"RMSE on Test Set: {rmse:.4f}")


In [None]:
import matplotlib.pyplot as plt
def plot_loss_curves(train_losses, val_losses):
    """Plot per-epoch training and validation loss curves.

    Epochs are numbered from 1 on the x-axis, matching the training log.

    Args:
        train_losses: Average training loss for each epoch.
        val_losses: Average validation loss for each epoch.
    """
    plt.figure(figsize=(10, 6))
    plt.plot(range(1, len(train_losses) + 1), train_losses, marker='o', label='Training Loss')
    plt.plot(range(1, len(val_losses) + 1), val_losses, marker='o', label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss Curves')
    plt.legend()
    plt.grid(True)
    plt.show()


# Assumes train_losses and val_losses exist (filled in by the training loop).
plot_loss_curves(train_losses, val_losses)