In [13]:
import os
import pandas as pd

# Your patient CSV file
patient_csv_path = "patient.csv"

# Load patient metadata
patients = pd.read_csv(patient_csv_path)
print("Patient sheet loaded:", patients.shape)

# Folder containing medical images
base_folder = "Stroke_classification"

# Map folder names to labels
label_map = {
    "Haemorrhagic": "Haemorrhagic",
    "Ischemic": "Ischemic",
    "Normal": "Normal"
}

image_paths = []
labels = []

for folder in os.listdir(base_folder):
    full_path = os.path.join(base_folder, folder)
    
    if os.path.isdir(full_path):
        for img in os.listdir(full_path):
            if img.lower().endswith((".png", ".jpg", ".jpeg")):
                image_paths.append(os.path.join(full_path, img))
                labels.append(folder)

# Create DataFrame
meta_df = pd.DataFrame({
    "image_path": image_paths,
    "Stroke_Type": labels
})

print("Images mapped:", meta_df.shape)

# SAVE for Module 2 usage
meta_df.to_csv("meta_df.csv", index=False)

print("Saved meta_df.csv successfully!")
meta_df.head()


Patient sheet loaded: (300, 8)
Images mapped: (297, 2)
Saved meta_df.csv successfully!


Unnamed: 0,image_path,Stroke_Type
0,Stroke_classification\Haemorrhagic\Patient_003...,Haemorrhagic
1,Stroke_classification\Haemorrhagic\Patient_004...,Haemorrhagic
2,Stroke_classification\Haemorrhagic\Patient_012...,Haemorrhagic
3,Stroke_classification\Haemorrhagic\Patient_013...,Haemorrhagic
4,Stroke_classification\Haemorrhagic\Patient_028...,Haemorrhagic


In [14]:
meta_df = pd.read_csv("meta_df.csv")
print(meta_df.head())


                                          image_path   Stroke_Type
0  Stroke_classification\Haemorrhagic\Patient_003...  Haemorrhagic
1  Stroke_classification\Haemorrhagic\Patient_004...  Haemorrhagic
2  Stroke_classification\Haemorrhagic\Patient_012...  Haemorrhagic
3  Stroke_classification\Haemorrhagic\Patient_013...  Haemorrhagic
4  Stroke_classification\Haemorrhagic\Patient_028...  Haemorrhagic


In [15]:
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms


In [16]:
class EnhancerNet(nn.Module):
    def __init__(self):
        super(EnhancerNet, self).__init__()

        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 2, stride=2),
            nn.ReLU(),

            nn.ConvTranspose2d(32, 1, 2, stride=2),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

enhancer_model = EnhancerNet()
print("Enhancer model ready.")


Enhancer model ready.


In [17]:
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])


In [23]:
def enhance_image(img_path):
    img = Image.open(img_path)
    img_tensor = transform(img).unsqueeze(0)

    with torch.no_grad():
        enhanced_tensor = enhancer_model(img_tensor)

    enhanced_np = enhanced_tensor.squeeze().numpy()

    return enhanced_np


In [24]:
sample_path = meta_df.iloc[0]["image_path"]
print("Testing enhancement on:", sample_path)

enhance_image(sample_path)


Testing enhancement on: Stroke_classification\Haemorrhagic\Patient_003.png


array([[0.03331919, 0.03562793, 0.03521232, ..., 0.032878  , 0.03395087,
        0.03207485],
       [0.03430567, 0.03643325, 0.03536519, ..., 0.03072106, 0.0320534 ,
        0.02790923],
       [0.03398917, 0.03642219, 0.03472058, ..., 0.0314109 , 0.03017635,
        0.02963807],
       ...,
       [0.03073784, 0.03140718, 0.03311124, ..., 0.02952928, 0.0321618 ,
        0.02709967],
       [0.02985315, 0.03209171, 0.03093318, ..., 0.0309866 , 0.02966819,
        0.02900667],
       [0.02977415, 0.02828303, 0.02988974, ..., 0.02625975, 0.02831389,
        0.02745545]], dtype=float32)

In [21]:
optimizer = torch.optim.Adam(enhancer_model.parameters(), lr=0.001)
loss_fn = nn.MSELoss()

def train_enhancement(epochs=3):
    enhancer_model.train()

    for e in range(epochs):
        total_loss = 0

        for idx, row in meta_df.iterrows():
            img_path = row["image_path"]

            try:
                img = Image.open(img_path)
                img = transform(img).unsqueeze(0)

                noise = torch.randn_like(img) * 0.1
                noisy_img = img + noise

                out = enhancer_model(noisy_img)
                loss = loss_fn(out, img)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            except:
                continue

        print(f"Epoch {e+1}/{epochs} - Loss: {total_loss:.4f}")

train_enhancement(epochs=5)


Epoch 1/5 - Loss: 3.7779
Epoch 2/5 - Loss: 0.6465
Epoch 3/5 - Loss: 0.5709
Epoch 4/5 - Loss: 0.5280
Epoch 5/5 - Loss: 0.5025


In [25]:
import os
os.makedirs("Enhanced_Output", exist_ok=True)

def save_all_enhanced():
    for idx, row in meta_df.iterrows():
        img_path = row["image_path"]
        enhanced = enhance_image(img_path)

        # Convert float â†’ uint8 (0-255)
        enhanced_img = (enhanced * 255).astype("uint8")

        out_path = f"Enhanced_Output/enh_{idx}.png"
        cv2.imwrite(out_path, enhanced_img)

    print("All enhanced images saved in Enhanced_Output/")


In [26]:
save_all_enhanced()


All enhanced images saved in Enhanced_Output/
