In [None]:
!pip install dicom2nifti
!pip install monai nibabel torch torchvision tqdm

In [None]:
import os
import numpy as np
import pandas as pd

import glob, shutil, tempfile
import dicom2nifti
import dicom2nifti.settings as settings
import torch
import nibabel as nib
from monai.transforms import (
    LoadImage,
    EnsureChannelFirst,
    Spacing,
    Orientation,
    ScaleIntensity,
    Resize,
    ToTensor
)
from monai.networks.nets import resnet
from tqdm import tqdm
import torch.nn.functional as F

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import nibabel as nib

# Convert DICOM series to NIfTI

In [None]:
# Raw DICOM tree, walked as {DATASET_DIR}/{class}/{subject}/**/*.dcm by the conversion loop.
DATASET_DIR = "/kaggle/input/data-ad/Dataset"
# Converted volumes land in {OUTPUT_DIR}/{class}/{subject}.nii.gz.
OUTPUT_DIR = "/kaggle/working/NIFTI_DATA99"
os.makedirs(name=OUTPUT_DIR, exist_ok=True)

In [None]:
# Turn off dicom2nifti's orthogonality validation (per the setting's name) so
# series that fail that check are still converted instead of rejected.
# NOTE(review): spot-check volumes from such series; they may be geometrically skewed.
settings.disable_validate_orthogonal() 

In [None]:
def convert_subject(dicom_dir, out_dir, sub_id):
    """Convert one subject's DICOM directory to ``{out_dir}/{sub_id}.nii.gz``.

    dicom2nifti may write one or more compressed NIfTI files into a scratch
    directory. The first one in sorted filename order is kept, so repeated
    runs always pick the same file.

    Args:
        dicom_dir: Directory holding the subject's ``.dcm`` slices.
        out_dir: Existing directory that receives the converted volume.
        sub_id: Subject identifier, used as the output file stem.

    Returns:
        Path of the written ``.nii.gz`` file.

    Raises:
        RuntimeError: If dicom2nifti produced no ``.nii.gz`` output.
    """
    with tempfile.TemporaryDirectory() as tmp_out:
        dicom2nifti.convert_directory(dicom_dir, tmp_out, compression=True, reorient=True)

        # glob order is filesystem-dependent; sort so the chosen file is deterministic.
        nii_files = sorted(glob.glob(os.path.join(tmp_out, "*.nii.gz")))
        if not nii_files:
            raise RuntimeError(
                f"dicom2nifti produced no NIfTI output for {sub_id!r} ({dicom_dir})"
            )
        if len(nii_files) > 1:
            # Only one volume is kept per subject; flag the choice for review.
            print(
                f"warning: {sub_id}: {len(nii_files)} volumes converted, "
                f"keeping {os.path.basename(nii_files[0])}"
            )

        dst_file = os.path.join(out_dir, f"{sub_id}.nii.gz")
        # Move the file out before the temporary directory is deleted on exit.
        shutil.move(nii_files[0], dst_file)
    return dst_file

In [None]:
# Walk {DATASET_DIR}/{class}/{subject}/ and convert each subject's DICOMs to
# {OUTPUT_DIR}/{class}/{subject}.nii.gz. Directories are visited in sorted
# order so runs (and their logs) are reproducible.
for cls in sorted(os.listdir(DATASET_DIR)):
    cls_path = os.path.join(DATASET_DIR, cls)
    if not os.path.isdir(cls_path):
        continue

    out_cls_dir = os.path.join(OUTPUT_DIR, cls)
    os.makedirs(out_cls_dir, exist_ok=True)

    for subj in sorted(os.listdir(cls_path)):
        subj_path = os.path.join(cls_path, subj)
        if not os.path.isdir(subj_path):
            continue

        dicom_files = sorted(glob.glob(os.path.join(subj_path, "**", "*.dcm"), recursive=True))
        if not dicom_files:
            # Skip subjects without any .dcm file instead of aborting the whole run.
            print(f"skipping {cls}/{subj}: no .dcm files found")
            continue

        # NOTE(review): only the directory holding the first (sorted) .dcm file is
        # converted; subjects with several series directories keep just that one.
        dicom_dir = os.path.dirname(dicom_files[0])
        convert_subject(dicom_dir, out_cls_dir, subj)

In [None]:
!ls /kaggle/working/NIFTI_DATA/ADNI1_T1w_Cohort_AD_Visit_12_MRI | wc -l

In [None]:
# Archive the converted NIfTI tree (presumably for re-upload as a Kaggle input dataset).
# zip stores the path as given (minus the leading "/"), so the archive contains
# kaggle/working/NIFTI_DATA99/... — which is why DATA_DIR below repeats that suffix.
!zip -r NIFTI_DATA99.zip /kaggle/working/NIFTI_DATA99


# Fine-tune a 3-D ResNet on the converted NIfTI volumes, then extract per-subject MRI embeddings

In [None]:
# Converted NIfTI volumes, read back from an attached input dataset; the trailing
# kaggle/working/NIFTI_DATA99 mirrors the absolute path stored inside the zip.
DATA_DIR = "/kaggle/input/nifti-data99/kaggle/working/NIFTI_DATA99"
# Per-subject embedding .npy files are written under {EMBEDDING_SAVE_DIR}/{class}/.
EMBEDDING_SAVE_DIR = "/kaggle/working/MRI_Embeddings"
os.makedirs(EMBEDDING_SAVE_DIR, exist_ok=True)

In [None]:
# Mixed-precision helpers used by the training cells.
# NOTE(review): torch.cuda.amp is deprecated in recent PyTorch in favour of torch.amp,
# and this autocast() only takes effect on CUDA — confirm behaviour on a CPU-only run.
from torch.cuda.amp import GradScaler, autocast
# Use the GPU when present; model and batches are moved to `device` below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Number of diagnostic classes (one per entry in class_to_idx).
num_classes = 3  
# 3-D single-channel ResNet-10. The 1000-way head is a placeholder; it is replaced
# by a `num_classes`-way head after the pretrained weights are loaded.
model = resnet.resnet10(spatial_dims=3, n_input_channels=1, num_classes=1000)


In [None]:
from collections import OrderedDict

# Pretrained 3-D ResNet-10 weights.
checkpoint = torch.load("/kaggle/input/resnet-10-23/pytorch/default/1/resnet_10_23dataset.pth", map_location=device)
# Weights are nested under "state_dict"; fall back to the object itself if it already is one.
state_dict = checkpoint.get('state_dict', checkpoint)

# Keys presumably come from a DataParallel-wrapped model ("module." prefix).
# Strip only that leading prefix, never occurrences elsewhere in a key.
PREFIX = "module."
new_state_dict = OrderedDict(
    (k[len(PREFIX):] if k.startswith(PREFIX) else k, v) for k, v in state_dict.items()
)

# strict=False tolerates the placeholder head, but it would also silently accept a
# checkpoint whose keys match nothing, so report what was and was not loaded.
load_result = model.load_state_dict(new_state_dict, strict=False)
print(f"missing keys ({len(load_result.missing_keys)}):", load_result.missing_keys)
print(f"unexpected keys ({len(load_result.unexpected_keys)}):", load_result.unexpected_keys)

In [None]:
# Replace the placeholder head with a dropout-regularised `num_classes`-way linear head.
# Its weights are stored as fc.1.weight / fc.1.bias (index 1 of the Sequential), so any
# model that later loads these weights needs the same head layout.
in_features = model.fc.in_features
model.fc = nn.Sequential(
    nn.Dropout(0.2),
    nn.Linear(in_features, num_classes)
)

In [None]:
# Sanity check: the model's stem convolution and the checkpoint's conv1 weights
# should have identical shapes (raises KeyError if the checkpoint lacks conv1.weight).
print(model.conv1.weight.shape)
print(new_state_dict['conv1.weight'].shape)

In [None]:
# Width of the feature vector fed into the classification head.
in_features

In [None]:
# Display the replacement head (Dropout -> Linear).
model.fc

In [None]:
# Spread batches across all visible GPUs. Once wrapped, the network is reachable as
# `model.module` and parameter names gain a "module." prefix; code that relies on
# either must also handle the unwrapped single-GPU / CPU case.
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs!")
    model = nn.DataParallel(model)

model = model.to(device)


In [None]:
# Make every parameter trainable.
# NOTE(review): the next cell re-freezes everything except layer4 and fc, so when
# cells run in order this has no lasting effect.
for param in model.parameters():
    param.requires_grad = True

In [None]:
# Partial fine-tuning: a parameter keeps its gradient only when its name contains
# one of these sub-module names; the rest of the backbone is frozen.
unfreeze_layers = ["layer4", "fc"]
for param_name, tensor in model.named_parameters():
    tensor.requires_grad = any(layer in param_name for layer in unfreeze_layers)
        

In [None]:
# Count total vs. trainable parameters, and list trainable ones whose names
# contain "conv1" or "fc" as a quick check that the freezing took effect.
total_params = sum(tensor.numel() for _, tensor in model.named_parameters())
trainable_params = 0
for param_name, tensor in model.named_parameters():
    if not tensor.requires_grad:
        continue
    trainable_params += tensor.numel()
    if any(tag in param_name for tag in ("conv1", ".0.conv1", "fc")):
        print(f"TRAIN: {param_name}")

print(f"{trainable_params:,} of {total_params:,}")

In [None]:
model = model.to(device)

# Give AdamW only the parameters left trainable by the freezing cell, so the frozen
# backbone is never registered with (or stepped by) the optimizer.
trainable_parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_parameters, lr=1e-4, weight_decay=0)
criterion = nn.CrossEntropyLoss()
# Loss scaler for mixed-precision training; paired with autocast in the training loop.
scaler = torch.amp.GradScaler()

In [None]:
from torch.utils.data import Dataset, DataLoader
import gc

In [None]:
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    ScaleIntensityd,
    NormalizeIntensityd,
    Resized,
    EnsureTyped
)

In [None]:
from monai.transforms import Spacingd, Orientationd
# Dictionary transform applied to {"image": path, "label": idx} records (only
# "image" is transformed). Used unchanged for training and validation; no
# augmentation is applied. Any inference path should apply this same pipeline.
preprocess = Compose([
    # Read the NIfTI file and move the channel axis to the front.
    LoadImaged(keys=["image"]),
    EnsureChannelFirstd(keys=["image"]),
    # Resample to 1 mm isotropic voxels, then reorient to RAS.
    Spacingd(keys=["image"], pixdim=(1.0,1.0,1.0), mode="bilinear"),
    Orientationd(keys=["image"], axcodes="RAS"),
    # Z-score intensities using statistics of the non-zero voxels only.
    NormalizeIntensityd(keys=["image"], nonzero=True, channel_wise=True),
    # NOTE(review): resizing each axis to 128 independently does not preserve the
    # 1 mm isotropic aspect ratio set above.
    Resized(keys=["image"], spatial_size=(128,128,128)),
    EnsureTyped(keys=["image"]),
])


In [None]:
import os
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

In [None]:
class MRIDataset(Dataset):
    """Map-style dataset over ``{"image": ..., "label": ...}`` records.

    Each record is shallow-copied before ``transform`` runs, so transforms that
    add or replace keys never modify the caller's list. Items are returned as
    ``(image, label)`` with ``label`` converted to a ``torch.long`` tensor.
    """

    def __init__(self, files, transform):
        self.files = files
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        sample = self.transform(dict(self.files[idx]))
        image = sample["image"]
        raw_label = sample.get("label", None)
        if torch.is_tensor(raw_label):
            return image, raw_label.long()
        return image, torch.tensor(raw_label, dtype=torch.long)


In [None]:
# Records gathered from DATA_DIR as {"image": path, "label": class index}.
all_files = []
# Class folder name -> integer target. The same folder names are listed again in
# `classes` in the embedding-loading section; keep the two in sync.
class_to_idx = {
    "ADNI1_T1w_Cohort_AD_Visit_12_MRI": 0, 
    "ADNI1_T1w_DXMCI=1_at_m18_MRI": 1, 
    "ADNI1_T1w_Normal_at_m12_MRI": 2
}



In [None]:
# Build the list of labelled volumes. The list is rebuilt from scratch so that
# re-running this cell does not duplicate entries. Files are sorted because
# os.listdir order is arbitrary, and train_test_split(random_state=42) is only
# reproducible when its input order is fixed.
all_files = []
for cls_name, idx in class_to_idx.items():
    cls_path = os.path.join(DATA_DIR, cls_name)
    if not os.path.exists(cls_path):
        print(f"warning: class folder missing, skipped: {cls_path}")
        continue
    for img in sorted(os.listdir(cls_path)):
        if img.endswith(('.nii', '.nii.gz')):
            all_files.append({"image": os.path.join(cls_path, img), "label": idx})

print(f"{len(all_files)} volumes found")

In [None]:
# Stratified 75/25 train/validation split over volumes (one file per subject),
# reproducible via random_state.
train_files, val_files = train_test_split(all_files, test_size=0.25, stratify=[x['label'] for x in all_files], random_state=42)

train_ds = MRIDataset(train_files, transform=preprocess)
val_ds = MRIDataset(val_files, transform=preprocess)

# drop_last discards a final partial training batch; validation keeps every sample.
train_loader = DataLoader(train_ds, batch_size=20, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
val_loader = DataLoader(val_ds, batch_size=20, shuffle=False, num_workers=4, pin_memory=True)

print(len(train_ds))
print(len(val_ds))

In [None]:
# Inspect one shuffled training batch: tensor shape, intensity statistics after
# preprocessing, and label dtype / classes present.
batch = next(iter(train_loader))
images, labels = batch
print("images.shape", images.shape)
print("images.dtype", images.dtype, "min/max/mean/std:", images.min().item(), images.max().item(), images.mean().item(), images.std().item())
print("labels.shape", labels.shape, "labels.dtype", labels.dtype, "unique:", torch.unique(labels))

In [None]:
# Release cached CUDA memory and run Python garbage collection before training.
torch.cuda.empty_cache()
gc.collect()

In [None]:
# Sanity check: a healthy model/pipeline should be able to overfit a single batch
# (loss -> ~0, accuracy -> 100%). It runs on a deep copy with its own optimizer and
# GradScaler, so the real `model`, `optimizer` and `scaler` used by the training
# loop are left untouched.
import copy

images, labels = next(iter(train_loader))
images, labels = images.to(device), labels.to(device)

probe_model = copy.deepcopy(model)
probe_model.train()
probe_optimizer = torch.optim.AdamW(
    [p for p in probe_model.parameters() if p.requires_grad], lr=1e-4, weight_decay=0
)
probe_scaler = torch.amp.GradScaler()

for i in range(50):
    probe_optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        outputs = probe_model(images)
        loss = criterion(outputs, labels)

    probe_scaler.scale(loss).backward()
    probe_scaler.step(probe_optimizer)
    probe_scaler.update()

    if i % 10 == 0:
        preds = outputs.argmax(dim=1)
        acc = (preds == labels).float().mean().item()
        print(f"Step {i}: Loss {loss.item():.4f} | Acc: {acc*100:.1f}%")

# Free the copy's memory before real training starts.
del probe_model, probe_optimizer, probe_scaler
torch.cuda.empty_cache()

In [None]:
EPOCHS = 50
best_metric = -1
BEST_MODEL_PATH = "best_model.pth"

for epoch in range(EPOCHS):
    # ---- training ----
    model.train()
    train_loss, train_correct, train_total = 0.0, 0, 0

    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1} Training"):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        with autocast():
            outputs = model(images)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss += loss.item()
        predicted = outputs.argmax(dim=1)
        train_total += labels.size(0)
        train_correct += (predicted == labels).sum().item()

    # ---- validation ----
    model.eval()
    val_loss, val_correct, val_total = 0.0, 0, 0

    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc=f"Epoch {epoch+1} Validation"):
            images, labels = images.to(device), labels.to(device)

            with autocast():
                outputs = model(images)
                loss = criterion(outputs, labels)

            val_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()

    train_acc = 100 * train_correct / train_total
    val_acc = 100 * val_correct / val_total
    # Losses were summed per batch; report the mean batch loss.
    avg_train_loss = train_loss / max(len(train_loader), 1)
    avg_val_loss = val_loss / max(len(val_loader), 1)

    print(f"Epoch {epoch+1}: Train Loss: {avg_train_loss:.4f} Acc: {train_acc:.2f}% | "
          f"Val Loss: {avg_val_loss:.4f} Acc: {val_acc:.2f}%")

    if val_acc > best_metric:
        best_metric = val_acc
        # `model.module` only exists when wrapped in DataParallel (>1 GPU). Save the
        # underlying network either way, so keys never carry a "module." prefix.
        torch.save(getattr(model, "module", model).state_dict(), BEST_MODEL_PATH)
        print("New Best")

In [None]:
# Inspect a single training batch: shapes, labels, and intensity range after preprocessing.
data_iter = iter(train_loader)
img, lbl = next(data_iter)

print(f"Img Batch Shape: {img.shape}")
print(f"Lbl Batch: {lbl}")
print(f"Img Max Value: {img.max().item()}")
print(f"Img Min Value: {img.min().item()}")

In [None]:
#torch.save(model.state_dict(), "finetuned_medicalnet_resnet50_pro.pth")

The fine-tuned model is reloaded below and used to extract per-subject MRI embeddings.

In [None]:
# Rebuild the network exactly as it was fine-tuned, including the Dropout+Linear head,
# so the checkpoint keys (fc.1.*) line up one-to-one and load with strict=True.
# A missing or unexpected key now raises instead of leaving layers silently random.
model = resnet.resnet10(spatial_dims=3, n_input_channels=1, num_classes=3)
model.fc = nn.Sequential(
    nn.Dropout(0.2),
    nn.Linear(model.fc.in_features, 3),
)
state_dict = torch.load("/kaggle/working/best_model.pth", map_location=device)
# Also accept a checkpoint saved from the DataParallel wrapper ("module." prefix).
state_dict = {
    (k[len("module."):] if k.startswith("module.") else k): v
    for k, v in state_dict.items()
}
model.load_state_dict(state_dict, strict=True)

In [None]:
model = model.to(device)
# Inference mode: BatchNorm uses running statistics and Dropout is disabled.
model.eval()

# Backbone without the classification head: every top-level child except the last.
# NOTE(review): assumes `fc` is the last registered child and that running the
# children in registration order reproduces the network's forward pass up to `fc`
# (flattening is done afterwards in extract_embedding) — confirm for the installed MONAI.
feature_extractor = nn.Sequential(*list(model.children())[:-1])

In [None]:
def extract_embedding(img_tensor):
    """Run ``feature_extractor`` on a batch and return flattened features.

    Args:
        img_tensor: Input batch; the caller passes a 5-D (batch, channel, D, H, W)
            tensor already on the extractor's device.

    Returns:
        numpy array of shape (batch, n_features) with size-1 axes squeezed, so a
        single-volume batch yields a 1-D feature vector.
    """
    with torch.no_grad():
        feat = feature_extractor(img_tensor)
        # Flatten everything after the batch axis; reshape also copes with
        # non-contiguous outputs, which view() rejects.
        feat = feat.reshape(feat.size(0), -1)
    return feat.cpu().numpy().squeeze()

In [None]:
# Extract one embedding per volume. Each volume goes through `preprocess`, the same
# pipeline used for fine-tuning (1 mm resampling, RAS orientation, non-zero z-scoring,
# 128^3 resize), so the encoder sees the input distribution it was trained on. Raw
# nibabel intensities would not match that distribution.
for cls in sorted(os.listdir(DATA_DIR)):
    cls_dir = os.path.join(DATA_DIR, cls)
    if not os.path.isdir(cls_dir):
        continue

    save_cls_dir = os.path.join(EMBEDDING_SAVE_DIR, cls)
    os.makedirs(save_cls_dir, exist_ok=True)

    nii_files = sorted(f for f in os.listdir(cls_dir) if f.endswith((".nii", ".nii.gz")))

    for file in tqdm(nii_files, desc=f"Processing {cls}"):
        img_path = os.path.join(cls_dir, file)

        # Channel-first volume from `preprocess`; add the batch axis to make it 5-D.
        img_tensor = preprocess({"image": img_path})["image"]
        img_tensor = torch.nan_to_num(torch.as_tensor(img_tensor, dtype=torch.float32))
        img_tensor = img_tensor.unsqueeze(0).to(device)

        embedding = extract_embedding(img_tensor)

        # Strip the extension; ".nii.gz" must be checked before ".nii".
        subj_id = file[:-7] if file.endswith(".nii.gz") else file[:-4]

        np.save(os.path.join(save_cls_dir, f"{subj_id}_embedding.npy"), embedding)

In [None]:
# Embeddings written by the extraction loop (same path as EMBEDDING_SAVE_DIR).
MRI_DIR = "/kaggle/working/MRI_Embeddings"
# Class folder names; must match the keys of class_to_idx.
classes = ["ADNI1_T1w_Cohort_AD_Visit_12_MRI", "ADNI1_T1w_DXMCI=1_at_m18_MRI", "ADNI1_T1w_Normal_at_m12_MRI"]
records = []

In [None]:
# Show the embeddings root in use.
MRI_DIR

In [None]:
# Load every saved embedding into one record per subject. `records` is rebuilt here
# so re-running the cell does not append duplicates. The "_embedding" suffix added
# when the .npy files were written is stripped, so subject_id equals the NIfTI stem.
EMBEDDING_SUFFIX = "_embedding"
records = []
for cls in classes:
    cls_path = os.path.join(MRI_DIR, cls)
    for f in sorted(os.listdir(cls_path)):
        if not f.endswith(".npy"):
            continue
        stem = os.path.splitext(f)[0]
        subj_id = stem[:-len(EMBEDDING_SUFFIX)] if stem.endswith(EMBEDDING_SUFFIX) else stem
        emb = np.load(os.path.join(cls_path, f))
        records.append({"subject_id": subj_id, "label": cls, "embedding": emb})

In [None]:
# One row per subject: subject_id, class folder name (label), embedding array.
# NOTE(review): displaying the full frame renders every embedding; .head() is lighter.
mri_df = pd.DataFrame(records)
mri_df

In [None]:
# Feature matrix (one flattened embedding per row), string class labels, subject ids.
flat_rows = [np.ravel(emb) for emb in mri_df['embedding']]
X_mri = np.stack(flat_rows)
y = mri_df['label']
subjects = mri_df['subject_id']

In [None]:
# Expect (n_subjects, embedding_dim).
print(X_mri.shape)

In [None]:
from sklearn.preprocessing import StandardScaler, RobustScaler

# Named `feature_scaler` so it does not overwrite `scaler`, the AMP GradScaler used
# by the training cells. Re-running training after this cell would otherwise call
# .scale() on a RobustScaler.
# NOTE(review): fitted on all subjects before the train/test split. This is harmless
# for the tree-based classifier below (per-feature monotone transform), but it leaks
# information for scale-sensitive models. Re-running this cell rescales the
# already-scaled X_mri.
feature_scaler = RobustScaler()
X_mri = feature_scaler.fit_transform(X_mri)

In [None]:
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sns

# 2-D t-SNE projection of the scaled MRI embeddings, coloured by class.
# matplotlib is imported here because its original import appears only after this cell.
# t-SNE requires perplexity < n_samples, so cap it for small cohorts.
perplexity = min(30, len(X_mri) - 1)
tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42, init='pca', learning_rate='auto')
X_tsne = tsne.fit_transform(X_mri)

fig, ax = plt.subplots(figsize=(8, 6))
sns.scatterplot(x=X_tsne[:, 0], y=X_tsne[:, 1], hue=y, palette="deep", s=60, ax=ax)
ax.set(title="MRI Embedding", xlabel="t-SNE 1", ylabel="t-SNE 2")
plt.show()

In [None]:
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

In [None]:
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, accuracy_score

In [None]:
# Hold out the SAME subjects that formed the CNN validation split (val_files), so the
# classifier is scored on subjects the fine-tuned encoder never trained on. A fresh
# random split would put encoder-training subjects in the test set and inflate accuracy.
# NOTE(review): these subjects still drove best-checkpoint selection, so the estimate
# is mildly optimistic; a separate untouched test split would be better.
def _nifti_stem(filename):
    """Return `filename` without a trailing .nii.gz / .nii extension."""
    for ext in (".nii.gz", ".nii"):
        if filename.endswith(ext):
            return filename[:-len(ext)]
    return filename

# (class folder, subject id) pairs of the CNN validation volumes.
val_keys = {
    (os.path.basename(os.path.dirname(f["image"])), _nifti_stem(os.path.basename(f["image"])))
    for f in val_files
}
# Works whether or not subject ids still carry the "_embedding" file suffix.
subject_stems = subjects.str.replace(r"_embedding$", "", regex=True)
is_test = np.array([(lbl, sid) in val_keys for lbl, sid in zip(y, subject_stems)])
assert is_test.any(), "no embedding matched a CNN validation subject; check subject ids"

X_train, X_test = X_mri[~is_test], X_mri[is_test]
y_train, y_test = y[~is_test], y[is_test]
print(f"train={len(y_train)} test={len(y_test)}")

In [None]:
# 400-tree random forest on the embedding features; fixed seed for reproducibility.
rf = RandomForestClassifier(n_estimators=400, random_state=42)

In [None]:
# Fit on the training subjects only.
rf.fit(X_train, y_train)

In [None]:
# Predict class labels (folder names) for the held-out subjects.
y_pred = rf.predict(X_test)

In [None]:
# Fraction of held-out subjects classified correctly.
# NOTE(review): classification_report is imported but unused; per-class metrics would
# be more informative than accuracy alone for a 3-class problem.
print("MRI accuracy:", accuracy_score(y_test, y_pred))

In [None]:
# Archive the per-class embedding .npy files; as with the NIfTI archive, the absolute
# path (minus the leading "/") is stored inside the zip.
!zip -r MRI_Embeddings.zip /kaggle/working/MRI_Embeddings