In [2]:
import os
import numpy as np
from PIL import Image
import torch
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn

In [5]:
## Building the dataset

import os
import shutil
import random

def split_data(class_name, split_ratio=0.8):
    src_dir = f'./{class_name}'
    if not os.path.exists(src_dir):
        print(f"Source folder {src_dir} does not exist!")
        return

    # Include all files (ignore extensions)
    files = [f for f in os.listdir(src_dir) if os.path.isfile(os.path.join(src_dir, f))]
    print(f"Found {len(files)} files in {class_name}")

    if len(files) == 0:
        print(f"No files found in {src_dir}")
        return

    random.shuffle(files)
    split_idx = int(len(files) * split_ratio)
    train_files = files[:split_idx]
    val_files = files[split_idx:]

    for split, split_files in [('train', train_files), ('val', val_files)]:
        dst_dir = os.path.join('dataset', split, class_name)
        os.makedirs(dst_dir, exist_ok=True)
        for file in split_files:
            src_file = os.path.join(src_dir, file)
            dst_file = os.path.join(dst_dir, file)
            print(f"Copying {src_file} to {dst_file}")
            shutil.copy(src_file, dst_file)

split_data('demented')
split_data('non-demented')


Found 86 files in demented
Copying ./demented/OAS2_0159_MR2 to dataset/train/demented/OAS2_0159_MR2
Copying ./demented/OAS2_0108_MR1 to dataset/train/demented/OAS2_0108_MR1
Copying ./demented/OAS2_0144_MR1 to dataset/train/demented/OAS2_0144_MR1
Copying ./demented/OAS2_0176_MR3 to dataset/train/demented/OAS2_0176_MR3
Copying ./demented/OAS2_0184_MR2 to dataset/train/demented/OAS2_0184_MR2
Copying ./demented/OAS2_0102_MR3 to dataset/train/demented/OAS2_0102_MR3
Copying ./demented/OAS2_0139_MR2 to dataset/train/demented/OAS2_0139_MR2
Copying ./demented/OAS2_0165_MR1 to dataset/train/demented/OAS2_0165_MR1
Copying ./demented/OAS2_0172_MR1 to dataset/train/demented/OAS2_0172_MR1
Copying ./demented/OAS2_0140_MR3 to dataset/train/demented/OAS2_0140_MR3
Copying ./demented/OAS2_0172_MR2 to dataset/train/demented/OAS2_0172_MR2
Copying ./demented/OAS2_0127_MR2 to dataset/train/demented/OAS2_0127_MR2
Copying ./demented/OAS2_0182_MR1 to dataset/train/demented/OAS2_0182_MR1
Copying ./demented/OAS2_

In [3]:
train_root = "dataset/train"
val_root = "dataset/val"

train_samples = []
val_samples = []

for label_idx, class_name in enumerate(sorted(os.listdir(train_root))):
    class_path = os.path.join(train_root, class_name)
    for fname in os.listdir(class_path):
        train_samples.append((os.path.join(class_path, fname), label_idx))

for label_idx, class_name in enumerate(sorted(os.listdir(val_root))):
    class_path = os.path.join(val_root, class_name)
    for fname in os.listdir(class_path):
        val_samples.append((os.path.join(class_path, fname), label_idx))



In [4]:
# Define data transformations for data augmentation and normalization
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=10),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),  # shape: [1, H, W]
    transforms.Normalize(mean=[0.5], std=[0.5])
])


transform_val = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])


In [5]:
test_path = train_samples[0][0]
test_array = np.load(test_path)

print("Shape:", test_array.shape)
print("Dtype:", test_array.dtype)
print("Min:", test_array.min(), "Max:", test_array.max())

Shape: (20, 256, 115)
Dtype: float64
Min: -807.7960675865115 Max: 1642.5774777179804


In [6]:
class_names = sorted(os.listdir(train_root))  # ['demented', 'non-demented']
print(class_names)

def load_sample(path, transform=None):
    volume = np.load(path)  # Shape: (slices, height, width)
    
    # Get the center slice index
    center = volume.shape[0] // 2
    start = max(center - 10, 0)
    end = min(center + 11, volume.shape[0])  # +11 for inclusive indexing

    # Pad if necessary (at edges)
    if end - start < 21:
        padding_needed = 21 - (end - start)
        pad_before = padding_needed // 2
        pad_after = padding_needed - pad_before
        volume = np.pad(volume, ((pad_before, pad_after), (0, 0), (0, 0)), mode='constant')
        start = center - 10
        end = center + 11

    # Extract the 21 slices
    slices = volume[start:end]  # Shape: (21, H, W)

    # Normalize to [0, 1]
    slices = (slices - slices.min()) / (slices.max() - slices.min())

    # Convert to float32 tensor
    slices_tensor = torch.tensor(slices, dtype=torch.float32)

    if transform:
        # Convert each slice to PIL image and to 3-channel RGB
        slices_tensor = torch.stack([
            transform(Image.fromarray((s.numpy() * 255).astype(np.uint8)).convert('RGB'))
            for s in slices_tensor
        ])

    return slices_tensor

# Turn sample lists into datasets
train_dataset = [(load_sample(p, transform_train), label) for p, label in train_samples]
val_dataset = [(load_sample(p, transform_val), label) for p, label in val_samples]

['demented', 'non-demented']


In [7]:
def process_sample(p, label, transform):
    tensor = load_sample(p, transform)  # (21, 3, 224, 224)
    return tensor, label

train_data = [process_sample(p, label, transform_train) for p, label in train_samples]
val_data = [process_sample(p, label, transform_val) for p, label in val_samples]

train_inputs = torch.stack([x[0] for x in train_data])  # (N, 21, 3, 224, 224)
train_labels = torch.tensor([x[1] for x in train_data])

val_inputs = torch.stack([x[0] for x in val_data])
val_labels = torch.tensor([x[1] for x in val_data])

# Wrap in TensorDataset
train_dataset = torch.utils.data.TensorDataset(train_inputs, train_labels)
val_dataset = torch.utils.data.TensorDataset(val_inputs, val_labels)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4)

In [8]:
batch = next(iter(train_loader))
print(batch[0].shape)

torch.Size([4, 21, 3, 224, 224])


In [36]:
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

In [14]:
import torchvision.models.video as video_models  # assuming you're using 3D ResNet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load pretrained 3D ResNet
model = video_models.r3d_18(pretrained=True)

# Replace final layer for binary classification
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)  # Binary classification

# Unfreeze only the last ResNet block + FC layer
for name, param in model.named_parameters():
    if "layer4" in name or "fc" in name:
        param.requires_grad = True
    else:
        param.requires_grad = False

model = model.to(device)
print("Using device:", device)

# Loss function
criterion = nn.CrossEntropyLoss()

# Optimizer with only trainable parameters
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=1e-4,
    weight_decay=1e-5
)

# Learning rate scheduler
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)


Using device: cuda


In [None]:
num_epochs = 20

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print("=" * 40)

    for phase in ['train', 'val']:
        if phase == 'train':
            model.train()
            loader = train_loader
        else:
            model.eval()
            loader = val_loader

        running_loss = 0.0
        running_corrects = 0

        for inputs, labels in loader:
            # Reshape (B, 21, 3, 224, 224) → (B, 3, 21, 224, 224) for ResNet3D
            inputs = inputs.permute(0, 2, 1, 3, 4).to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            with torch.set_grad_enabled(phase == 'train'):
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                _, preds = torch.max(outputs, 1)

                if phase == 'train':
                    loss.backward()
                    optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        epoch_loss = running_loss / len(loader.dataset)
        epoch_acc = running_corrects.double() / len(loader.dataset)

        print(f"{phase.capitalize()} Loss: {epoch_loss:.4f} | Accuracy: {epoch_acc:.4f}")

    scheduler.step()



Epoch 1/50
⚠️ Skipping batch due to OOM
⚠️ Skipping batch due to OOM
⚠️ Skipping batch due to OOM
⚠️ Skipping batch due to OOM
⚠️ Skipping batch due to OOM
⚠️ Skipping batch due to OOM
⚠️ Skipping batch due to OOM
⚠️ Skipping batch due to OOM
⚠️ Skipping batch due to OOM
⚠️ Skipping batch due to OOM
⚠️ Skipping batch due to OOM
⚠️ Skipping batch due to OOM
⚠️ Skipping batch due to OOM
⚠️ Skipping batch due to OOM
⚠️ Skipping batch due to OOM
⚠️ Skipping batch due to OOM
⚠️ Skipping batch due to OOM
⚠️ Skipping batch due to OOM
⚠️ Skipping batch due to OOM
⚠️ Skipping batch due to OOM
⚠️ Skipping batch due to OOM
⚠️ Skipping batch due to OOM
⚠️ Skipping batch due to OOM
⚠️ Skipping batch due to OOM
⚠️ Skipping batch due to OOM
⚠️ Skipping batch due to OOM
⚠️ Skipping batch due to OOM
⚠️ Skipping batch due to OOM
⚠️ Skipping batch due to OOM
⚠️ Skipping batch due to OOM
⚠️ Skipping batch due to OOM
⚠️ Skipping batch due to OOM
⚠️ Skipping batch due to OOM


AttributeError: 'int' object has no attribute 'double'

In [20]:
torch.cuda.empty_cache()


In [49]:
# Save the model
torch.save(model.state_dict(), 'flower_classification_model.pth')

In [42]:
import os
import concurrent.futures
from PIL import Image

dataset_path = "dataset/train"  # Change this to your dataset path
num_threads = os.cpu_count()  # Use all CPU cores

def check_and_remove_image(file_path):
    """Verifies image and removes it if corrupted."""
    try:
        with Image.open(file_path) as img:
            img.verify()  # Verify without fully loading
    except Exception:
        print(f"Removing corrupted image: {file_path}")
        os.remove(file_path)  # Delete invalid image

# Get all image file paths from subdirectories
image_files = []
for root, _, files in os.walk(dataset_path):
    for file in files:
        image_files.append(os.path.join(root, file))

# Run in parallel for speed
with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
    executor.map(check_and_remove_image, image_files)

print("✅ Image cleanup complete!")


Removing corrupted image: dataset/train\daisy\4534460263_8e9611db3c_n_jpg.rf.2756de378a7a041406b7aa661912da99.jpg
✅ Image cleanup complete!


In [46]:
# Inspect one batch
inputs, labels = next(iter(train_loader))
print("Input shape:", inputs.shape)


Input shape: torch.Size([16, 3, 224, 224])


In [43]:
del image_datasets, dataloaders  # Clear previous dataset objects

## Classification on unseen image

In [51]:
import torch
from torchvision import models, transforms
from PIL import Image

# Load the saved model
model = models.resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 1000) 
model.load_state_dict(torch.load("flower_classification_model.pth"))
model.eval()

# Create a new model with the correct final layer
new_model = models.resnet18(pretrained=True)
new_model.fc = nn.Linear(new_model.fc.in_features, 2)

# Copy the weights and biases from the loaded model to the new model
new_model.fc.weight.data = model.fc.weight.data[0:2] # Copy only the first 2 output units
new_model.fc.bias.data = model.fc.bias.data[0:2]



In [53]:
# load and preprocess unseen image
image_path = 'test.jpg'
image = Image.open(image_path)
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

input_tensor = preprocess(image)
input_batch = input_tensor.unsqueeze(0) # Add a batch dimension

In [54]:
# Perform inference
with torch.no_grad():
    output = model(input_batch)

# Get the predicted class
_, predicted_class = output.max(1)

# Map the predicted class to the class name
class_names = ['daisy', 'dandelion']
predicted_class_name = class_names[predicted_class.item()]

print(f'The predicted class is: {predicted_class_name}')

The predicted class is: daisy


In [3]:
import numpy as np

file_path = "/home/omanaokar/Desktop/Alzeimers-detection/demented/OAS2_0102_MR1"

file = np.load(file_path)
file.shape

(20, 256, 115)