In [2]:
import os
import numpy as np
from PIL import Image
import torch
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn

In [6]:
## Building the dataset
# RUN ONLY ONCE
import os
import shutil
import random

def split_data(class_name, split_ratio=0.8):
    src_dir = f'./{class_name}'
    if not os.path.exists(src_dir):
        print(f"Source folder {src_dir} does not exist!")
        return

    # Include all files (ignore extensions)
    files = [f for f in os.listdir(src_dir) if os.path.isfile(os.path.join(src_dir, f))]
    print(f"Found {len(files)} files in {class_name}")

    if len(files) == 0:
        print(f"No files found in {src_dir}")
        return

    random.shuffle(files)
    split_idx = int(len(files) * split_ratio)
    train_files = files[:split_idx]
    val_files = files[split_idx:]

    for split, split_files in [('train', train_files), ('val', val_files)]:
        dst_dir = os.path.join('dataset', split, class_name)
        os.makedirs(dst_dir, exist_ok=True)
        for file in split_files:
            src_file = os.path.join(src_dir, file)
            dst_file = os.path.join(dst_dir, file)
            print(f"Copying {src_file} to {dst_file}")
            shutil.copy(src_file, dst_file)

split_data('demented')
split_data('non-demented')


Found 86 files in demented
Copying ./demented\OAS2_0164_MR1 to dataset\train\demented\OAS2_0164_MR1
Copying ./demented\OAS2_0106_MR1 to dataset\train\demented\OAS2_0106_MR1
Copying ./demented\OAS2_0184_MR2 to dataset\train\demented\OAS2_0184_MR2
Copying ./demented\OAS2_0182_MR2 to dataset\train\demented\OAS2_0182_MR2
Copying ./demented\OAS2_0159_MR1 to dataset\train\demented\OAS2_0159_MR1
Copying ./demented\OAS2_0108_MR2 to dataset\train\demented\OAS2_0108_MR2
Copying ./demented\OAS2_0113_MR2 to dataset\train\demented\OAS2_0113_MR2
Copying ./demented\OAS2_0102_MR1 to dataset\train\demented\OAS2_0102_MR1
Copying ./demented\OAS2_0179_MR2 to dataset\train\demented\OAS2_0179_MR2
Copying ./demented\OAS2_0102_MR2 to dataset\train\demented\OAS2_0102_MR2
Copying ./demented\OAS2_0150_MR1 to dataset\train\demented\OAS2_0150_MR1
Copying ./demented\OAS2_0185_MR1 to dataset\train\demented\OAS2_0185_MR1
Copying ./demented\OAS2_0157_MR1 to dataset\train\demented\OAS2_0157_MR1
Copying ./demented\OAS2_

In [7]:
import os
import numpy as np
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

augment = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.RandomResizedCrop(224, scale=(0.9, 1.0)),
    transforms.ColorJitter(brightness=0.1, contrast=0.1)
])

root_dirs = [
    'dataset/train/demented',
    'dataset/train/non-demented'
]

num_augmented_per_file = 4

for root in root_dirs:
    files = os.listdir(root)
    print(f"Augmenting {len(files)} files in {root}...")

    for fname in tqdm(files, desc=f"Processing {os.path.basename(root)}"):
        fpath = os.path.join(root, fname)
        try:
            volume = np.load(fpath)
        except Exception as e:
            print(f"Skipping {fpath}: {e}")
            continue

        if volume.ndim != 3:
            print(f"Skipping {fpath}: Not a 3D volume")
            continue

        if volume.shape[0] != 20:
            volume = np.transpose(volume, (2, 0, 1))

        for i in range(num_augmented_per_file):
            augmented_slices = []
            for slice_idx in range(volume.shape[0]):
                slice_img = volume[slice_idx]
                pil_img = Image.fromarray(np.uint8(slice_img)).convert("RGB")
                aug_img = augment(pil_img)
                aug_arr = np.array(aug_img.convert("L"))
                augmented_slices.append(aug_arr)

            augmented_volume = np.stack(augmented_slices, axis=0)
            out_name = fname + f"_aug{i+1}"
            if out_name.endswith('.npy'):
                out_name = out_name[:-4]
            out_path = os.path.join(root, out_name)
            with open(out_path, 'wb') as f:
                np.save(f, augmented_volume)


Augmenting 68 files in dataset/train/demented...


Processing demented: 100%|██████████| 68/68 [00:30<00:00,  2.24it/s]


Augmenting 62 files in dataset/train/non-demented...


Processing non-demented: 100%|██████████| 62/62 [00:26<00:00,  2.31it/s]


In [3]:
train_root = "dataset/train"
val_root = "dataset/val"

train_samples = []
val_samples = []

for label_idx, class_name in enumerate(sorted(os.listdir(train_root))):
    class_path = os.path.join(train_root, class_name)
    for fname in os.listdir(class_path):
        train_samples.append((os.path.join(class_path, fname), label_idx))

for label_idx, class_name in enumerate(sorted(os.listdir(val_root))):
    class_path = os.path.join(val_root, class_name)
    for fname in os.listdir(class_path):
        val_samples.append((os.path.join(class_path, fname), label_idx))



In [4]:
# Define data transformations for data augmentation and normalization
transform_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=5),
    transforms.ToTensor(),  # Converts (H, W) to (1, H, W)
    transforms.Normalize([0.5]*3, [0.5]*3)
])

transform_val = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])


In [5]:
test_path = train_samples[0][0]
test_array = np.load(test_path)

print("Shape:", test_array.shape)
print("Dtype:", test_array.dtype)
print("Min:", test_array.min(), "Max:", test_array.max())

Shape: (20, 256, 115)
Dtype: float64
Min: -939.9154701692341 Max: 1491.7522524275992


In [6]:
import numpy as np
import torch
from PIL import Image

def load_sample(path, transform=None):
    volume = np.load(path)  # Shape: (slices, height, width)

    # Get the center slice index
    center = volume.shape[0] // 2
    start = max(center - 10, 0)
    end = min(center + 11, volume.shape[0])

    # Pad if necessary
    if end - start < 21:
        padding_needed = 21 - (end - start)
        pad_before = padding_needed // 2
        pad_after = padding_needed - pad_before
        volume = np.pad(volume, ((pad_before, pad_after), (0, 0), (0, 0)), mode='constant')
        # Update indices after padding
        start = center - 10
        end = center + 11

    # Extract 21 slices
    slices = volume[start:end]  # Shape: (21, H, W)

    # Normalize to [0, 1]
    slices = (slices - slices.min()) / (slices.max() - slices.min() + 1e-8)

    processed = []
    for s in slices:
        img = Image.fromarray((s * 255).astype(np.uint8)).convert('RGB')  # Shape: (H, W, 3)
        if transform:
            img = transform(img)  # Shape: (3, H, W)
        processed.append(img)

    return torch.stack(processed)  # Shape: (21, 3, H, W)


In [15]:
from tqdm import tqdm

def process_sample(p, label, transform):
    tensor = load_sample(p, transform)  # (21, 3, 224, 224)
    return tensor, label

def collate_fn(batch):
    inputs, labels = zip(*[process_sample(p, label, transform_train) for p, label in batch])
    return torch.stack(inputs), torch.tensor(labels)

train_loader = torch.utils.data.DataLoader(train_samples, batch_size=1, shuffle=True, collate_fn=collate_fn)

def collate_fn_val(batch):
    inputs, labels = zip(*[process_sample(p, label, transform_val) for p, label in batch])
    return torch.stack(inputs), torch.tensor(labels)

val_loader = torch.utils.data.DataLoader(val_samples, batch_size=1, shuffle=False, collate_fn=collate_fn_val)



In [16]:
batch = next(iter(train_loader))
print(batch[0].shape)

torch.Size([1, 21, 3, 224, 224])


In [2]:
print("hello worlds")

hello worlds


In [8]:
# ⬇️ Import 2D ResNet50 instead of 3D ResNet
from torchvision.models import resnet50
import torch.nn as nn
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load pretrained 2D ResNet50
model = resnet50(pretrained=True)

# Replace final layer with Dropout + FC
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(
    nn.Dropout(p=0.5),
    nn.Linear(num_ftrs, 2)
)

# Freeze all layers except FC
for name, param in model.named_parameters():
    param.requires_grad = name.startswith("fc")

model = model.to(device)
print("Using device:", device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=1e-4,
    weight_decay=1e-5
)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)




Using device: cuda


In [12]:
def print_progress(current, total, prefix="", length=30):
    percent = f"{100 * (current / total):.1f}"
    filled = int(length * current // total)
    bar = "█" * filled + '-' * (length - filled)
    print(f"\r{prefix} |{bar}| {percent}% ({current}/{total})", end='\r')
    if current == total:
        print()  # Newline after complete


In [14]:
num_epochs = 10

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print("=" * 5)

    for phase in ['train', 'val']:
        if phase == 'train':
            model.train()
            loader = train_loader
        else:
            model.eval()
            loader = val_loader

        running_loss = 0.0
        running_corrects = 0

        for i, (inputs, labels) in enumerate(loader):
            inputs = inputs.view(-1, 3, 224, 224).to(device)  # (B*21, 3, 224, 224)
            labels = labels.unsqueeze(1).repeat(1, 21).view(-1).to(device)  # (B*21,)

            optimizer.zero_grad()

            with torch.set_grad_enabled(phase == 'train'):
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                _, preds = torch.max(outputs, 1)

                if phase == 'train':
                    loss.backward()
                    optimizer.step()


            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

            # 👇 Show lightweight progress bar per epoch
            print_progress(i + 1, len(loader), prefix=f"{phase.capitalize()}")

        epoch_loss = running_loss / len(loader.dataset)
        epoch_acc = running_corrects.double() / len(loader.dataset)

        print(f"{phase.capitalize()} Loss: {epoch_loss:.4f} | Accuracy: {epoch_acc:.4f}")

    scheduler.step()



Epoch 1/10
=====
Train |██████████████████████████████| 100.0% (650/650)
Train Loss: 15.6432 | Accuracy: 10.2800
Val |██████████████████████████████| 100.0% (34/34)
Val Loss: nan | Accuracy: 11.1176

Epoch 2/10
=====
Train |██████████████████████████████| 100.0% (650/650)
Train Loss: 15.4275 | Accuracy: 10.3831
Val |██████████████████████████████| 100.0% (34/34)
Val Loss: nan | Accuracy: 11.1176

Epoch 3/10
=====
Train |█████-------------------------| 19.1% (124/650)

KeyboardInterrupt: 

In [None]:
torch.cuda.empty_cache()


In [49]:
# Save the model
torch.save(model.state_dict(), 'flower_classification_model.pth')

In [42]:
import os
import concurrent.futures
from PIL import Image

dataset_path = "dataset/train"  # Change this to your dataset path
num_threads = os.cpu_count()  # Use all CPU cores

def check_and_remove_image(file_path):
    """Verifies image and removes it if corrupted."""
    try:
        with Image.open(file_path) as img:
            img.verify()  # Verify without fully loading
    except Exception:
        print(f"Removing corrupted image: {file_path}")
        os.remove(file_path)  # Delete invalid image

# Get all image file paths from subdirectories
image_files = []
for root, _, files in os.walk(dataset_path):
    for file in files:
        image_files.append(os.path.join(root, file))

# Run in parallel for speed
with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
    executor.map(check_and_remove_image, image_files)

print("✅ Image cleanup complete!")


Removing corrupted image: dataset/train\daisy\4534460263_8e9611db3c_n_jpg.rf.2756de378a7a041406b7aa661912da99.jpg
✅ Image cleanup complete!


In [46]:
# Inspect one batch
inputs, labels = next(iter(train_loader))
print("Input shape:", inputs.shape)


Input shape: torch.Size([16, 3, 224, 224])


In [43]:
del image_datasets, dataloaders  # Clear previous dataset objects

## Classification on unseen image

In [51]:
import torch
from torchvision import models, transforms
from PIL import Image

# Load the saved model
model = models.resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 1000) 
model.load_state_dict(torch.load("flower_classification_model.pth"))
model.eval()

# Create a new model with the correct final layer
new_model = models.resnet18(pretrained=True)
new_model.fc = nn.Linear(new_model.fc.in_features, 2)

# Copy the weights and biases from the loaded model to the new model
new_model.fc.weight.data = model.fc.weight.data[0:2] # Copy only the first 2 output units
new_model.fc.bias.data = model.fc.bias.data[0:2]



In [53]:
# load and preprocess unseen image
image_path = 'test.jpg'
image = Image.open(image_path)
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

input_tensor = preprocess(image)
input_batch = input_tensor.unsqueeze(0) # Add a batch dimension

In [54]:
# Perform inference
with torch.no_grad():
    output = model(input_batch)

# Get the predicted class
_, predicted_class = output.max(1)

# Map the predicted class to the class name
class_names = ['daisy', 'dandelion']
predicted_class_name = class_names[predicted_class.item()]

print(f'The predicted class is: {predicted_class_name}')

The predicted class is: daisy


In [3]:
import numpy as np

file_path = "/home/omanaokar/Desktop/Alzeimers-detection/demented/OAS2_0102_MR1"

file = np.load(file_path)
file.shape

(20, 256, 115)