In [None]:
import torch

# Environment sanity check: report the PyTorch build and, if present, the GPU.
print(torch.__version__)
print(torch.cuda.is_available())
# current_device()/get_device_name() raise on a machine without CUDA, so guard them.
if torch.cuda.is_available():
    print(torch.cuda.current_device())
    print(torch.cuda.get_device_name(0))
else:
    print("CUDA not available; running on CPU.")

In [None]:
%pip install opencv-python-headless
%pip install opencv-python
%pip install scikit-image
%pip install scikit-learn
%pip install tqdm
%pip install sympy==1.13.3

In [None]:
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

# Use the current CUDA device when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
import torch.nn as nn
import torchvision.models as models

class ResNetModel(nn.Module):
    """ResNet-50 backbone with a two-hidden-layer MLP classification head.

    The stem, ``layer1`` and ``layer2`` are frozen; ``layer3`` and ``layer4`` stay
    trainable, and the replacement head is trainable because freshly created modules
    default to ``requires_grad=True``.

    Args:
        pretrained: load ImageNet (IMAGENET1K_V1) weights for the backbone.
        num_classes: number of output logits.
    """

    def __init__(self, pretrained=True, num_classes=5):
        super().__init__()
        backbone_weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        self.resnet = models.resnet50(weights=backbone_weights)

        # Parameter-name fragments that remain trainable; everything else is frozen.
        trainable_parts = ("layer3", "layer4", "fc")
        for param_name, param in self.resnet.named_parameters():
            param.requires_grad = any(part in param_name for part in trainable_parts)

        backbone_width = self.resnet.fc.in_features
        head_layers = [
            nn.Linear(backbone_width, 1024),
            nn.BatchNorm1d(1024),
            nn.SiLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.SiLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(512, num_classes),
        ]
        self.resnet.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the class logits for the image batch ``x``."""
        return self.resnet(x)



class ResNet101_MRI(nn.Module):
    """ResNet-101 backbone with a two-hidden-layer ReLU MLP classification head.

    Only the stem and ``layer1`` are frozen; ``layer2``-``layer4`` stay trainable,
    and the replacement head is trainable because new modules default to
    ``requires_grad=True``.

    Args:
        pretrained: load ImageNet (IMAGENET1K_V1) weights for the backbone.
        num_classes: number of output logits.
    """

    def __init__(self, pretrained=True, num_classes=5):
        super().__init__()
        backbone_weights = models.ResNet101_Weights.IMAGENET1K_V1 if pretrained else None
        self.resnet = models.resnet101(weights=backbone_weights)

        # Parameter-name fragments that remain trainable; everything else is frozen.
        trainable_parts = ("layer2", "layer3", "layer4", "fc")
        for param_name, param in self.resnet.named_parameters():
            param.requires_grad = any(part in param_name for part in trainable_parts)

        backbone_width = self.resnet.fc.in_features
        head_layers = [
            nn.Linear(backbone_width, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(512, num_classes),
        ]
        self.resnet.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the class logits for the image batch ``x``."""
        return self.resnet(x)


In [None]:
def load_model(path, model_class, num_classes=5, map_location=None):
    """Build ``model_class`` without pretrained weights, load a saved state dict and set eval mode.

    Args:
        path: checkpoint file containing a ``state_dict``.
        model_class: callable accepting ``pretrained=`` and ``num_classes=`` keywords.
        num_classes: number of output classes the checkpoint was trained with.
        map_location: device to load onto; defaults to the notebook-wide ``device``.

    Returns:
        The model moved to the target device, in ``eval()`` mode.
    """
    target = device if map_location is None else torch.device(map_location)
    model = model_class(pretrained=False, num_classes=num_classes)
    # weights_only=True restricts unpickling to tensors and plain containers, which is
    # all a state dict needs, so a tampered checkpoint cannot run arbitrary code.
    state_dict = torch.load(path, map_location=target, weights_only=True)
    model.load_state_dict(state_dict)
    return model.to(target).eval()

# Single-view ensemble members: one ResNet-50 and one ResNet-101 per MRI plane.
model_axial_r50 = load_model(path="ResNet_Alzheimer_Axial_Multiclass.pth", model_class=ResNetModel)
model_axial_r101 = load_model(path="ResNet101_Alzheimer_Axial_Multiclass.pth", model_class=ResNet101_MRI)
model_sag_r50 = load_model(path="ResNet_Alzheimer_Sagittal_Multiclass.pth", model_class=ResNetModel)
model_sag_r101 = load_model(path="ResNet101_Alzheimer_Sagittal_Multiclass.pth", model_class=ResNet101_MRI)


In [None]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, ConcatDataset

# Training folders for both views; here they are only used to measure pixel statistics.
train_dir1 = 'D:/Licenta/Datasets/ADNI_Oficial/Processed/Axial/Train/'
train_dir2 = 'D:/Licenta/Datasets/ADNI_Oficial/Filtered/Sagittal/Train/'

# No Normalize step: the goal is to measure the raw ToTensor() pixel statistics.
simple_transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=3),
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ]
)

dataset1, dataset2 = (
    datasets.ImageFolder(root=folder, transform=simple_transform)
    for folder in (train_dir1, train_dir2)
)

combined_dataset = ConcatDataset([dataset1, dataset2])
combined_loader = DataLoader(combined_dataset, batch_size=64, shuffle=False)

def compute_mean_std(loader):
    """Per-channel pixel mean and standard deviation over every image yielded by ``loader``.

    Assumes the loader yields ``(images, labels)`` with images shaped (B, C, H, W), as an
    ``ImageFolder`` + ``ToTensor`` pipeline does.

    Sums of x and x**2 are accumulated in float64, so the std is the pooled std of all
    pixels. Averaging per-image stds instead would ignore the variance *between* images
    and underestimate the dataset std.

    Returns:
        (mean, std): float32 tensors of length C.

    Raises:
        ValueError: if the loader yields no images.
    """
    channel_sum = None
    channel_sq_sum = None
    pixel_count = 0
    for data, _ in loader:
        # (B, C, H, W) -> (C, B*H*W): one row of pixels per channel.
        flat = data.transpose(0, 1).reshape(data.size(1), -1).double()
        if channel_sum is None:
            channel_sum = torch.zeros(flat.size(0), dtype=torch.float64)
            channel_sq_sum = torch.zeros_like(channel_sum)
        channel_sum += flat.sum(dim=1)
        channel_sq_sum += (flat * flat).sum(dim=1)
        pixel_count += flat.size(1)

    if pixel_count == 0:
        raise ValueError("loader yielded no images")

    channel_mean = channel_sum / pixel_count
    # clamp guards against tiny negative values from floating-point cancellation.
    channel_var = (channel_sq_sum / pixel_count - channel_mean ** 2).clamp_min(0)
    return channel_mean.float(), channel_var.sqrt().float()


mean, std = compute_mean_std(combined_loader)

print(f"Mean: {mean}")
print(f"Std: {std}")


In [None]:
# Preprocessing for the four single-view ResNet ensemble members (256x256 input).
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    # Dataset statistics; the grayscale value is replicated across the 3 channels.
    transforms.Normalize(mean=[0.2192, 0.2192, 0.2192], std=[0.2474, 0.2474, 0.2474])
])

# NOTE(review): assumed to match the (alphabetical) ImageFolder class order the models were trained with.
class_names = ["AD", "CN", "EMCI", "LMCI", "MCI"]


def _load_ensemble_input(image_path, preprocess):
    """Read one image file, preprocess it and return a (1, C, H, W) tensor on ``device``."""
    # The context manager closes the file handle as soon as the pixels are decoded.
    with Image.open(image_path) as img:
        tensor = preprocess(img.convert("RGB"))
    return tensor.unsqueeze(0).to(device)


def predict_dual_view(axial_path, sagittal_path, w_a50=0.25, w_a101=0.25, w_s50=0.25, w_s101=0.25,
                      preprocess=transform):
    """Classify one axial/sagittal image pair with the weighted four-model ensemble.

    The logits of the axial ResNet-50/101 and sagittal ResNet-50/101 models are combined
    as a weighted sum and then passed through softmax.

    Args:
        axial_path, sagittal_path: image files for the two views.
        w_a50, w_a101, w_s50, w_s101: weights for the four models' logits.
        preprocess: transform applied to both images. It defaults to the 256x256
            ``transform`` above, bound when this cell runs. A later cell that rebinds the
            global name ``transform`` (e.g. to a 224x224 pipeline) therefore cannot
            silently change this function's input size.

    Returns:
        (pred_idx, softmax_scores): int index into ``class_names`` and a (1, num_classes)
        probability tensor on ``device``.
    """
    axial_tensor = _load_ensemble_input(axial_path, preprocess)
    sagittal_tensor = _load_ensemble_input(sagittal_path, preprocess)

    with torch.no_grad(), torch.amp.autocast(device_type=device.type, enabled=device.type == 'cuda'):
        out_a50 = model_axial_r50(axial_tensor)
        out_a101 = model_axial_r101(axial_tensor)
        out_s50 = model_sag_r50(sagittal_tensor)
        out_s101 = model_sag_r101(sagittal_tensor)

        final_output = (
            w_a50 * out_a50 +
            w_a101 * out_a101 +
            w_s50 * out_s50 +
            w_s101 * out_s101
        )

        softmax_scores = torch.nn.functional.softmax(final_output, dim=1)
        pred_idx = torch.argmax(softmax_scores, dim=1).item()

    return pred_idx, softmax_scores


In [None]:
import os

def find_matching_pairs(axial_root, sagittal_root):
    """Pair axial and sagittal images that share a class folder and file name.

    For every class sub-folder present under both roots, each axial file is matched with
    the sagittal file whose name equals the axial name with every "Axial" replaced by
    "Sagittal". Axial files without such a counterpart, and entries that are not
    directories in both roots, are skipped.

    Returns:
        List of ``(label, axial_path, sagittal_path)`` tuples, sorted by label and then by
        axial file name. ``os.listdir`` order is filesystem-dependent, so sorting keeps the
        result, and any seeded split built from it, reproducible.
    """
    matches = []
    for label in sorted(os.listdir(axial_root)):
        axial_class_dir = os.path.join(axial_root, label)
        sagittal_class_dir = os.path.join(sagittal_root, label)
        if not (os.path.isdir(axial_class_dir) and os.path.isdir(sagittal_class_dir)):
            continue

        # Set for O(1) membership tests while scanning the axial files.
        sagittal_names = set(os.listdir(sagittal_class_dir))
        for axial_name in sorted(os.listdir(axial_class_dir)):
            sagittal_name = axial_name.replace("Axial", "Sagittal")
            if sagittal_name in sagittal_names:
                matches.append((
                    label,
                    os.path.join(axial_class_dir, axial_name),
                    os.path.join(sagittal_class_dir, sagittal_name),
                ))
    return matches


In [None]:
import csv

axial_dir = "D:/Licenta/Datasets/ADNI_Oficial/Processed/Axial/Test/"
sagittal_dir = "D:/Licenta/Datasets/ADNI_Oficial/Filtered/Sagittal/Test/"
output_csv = "D:/Licenta/rezultate_dualview.csv"

# Held-out test pairs for the four-model ensemble.
pairs = find_matching_pairs(axial_dir, sagittal_dir)
print(f"Total perechi găsite: {len(pairs)}\n")
# Fail with a clear message instead of a ZeroDivisionError in the accuracy computation below.
if not pairs:
    raise ValueError(f"No matching axial/sagittal pairs found under {axial_dir} and {sagittal_dir}")

header = ["TrueLabel", "PredictedLabel", "AxialPath", "SagittalPath"] + class_names

rows = []
correct = 0

for idx, (true_label, axial_img_path, sagittal_img_path) in enumerate(pairs):
    pred_idx, probs = predict_dual_view(axial_img_path, sagittal_img_path)
    predicted_label = class_names[pred_idx]
    # A single device->host transfer for the whole probability row, not one per class.
    prob_values = [round(p, 4) for p in probs[0].float().cpu().tolist()]

    rows.append([true_label, predicted_label, axial_img_path, sagittal_img_path] + prob_values)

    if predicted_label == true_label:
        correct += 1

    print(f"#{idx+1}")
    print(f"GT : {true_label}")
    print(f"Pred: {predicted_label}")
    for i, val in enumerate(prob_values):
        print(f"  {class_names[i]:<5}: {val:.4f}")
    print("-" * 40)

acc = correct / len(pairs) * 100
print(f"\n Accuracy pe perechi: {acc:.2f}%")

# Explicit encoding so the CSV does not depend on the OS locale (e.g. cp1252 on Windows).
with open(output_csv, "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(header)
    writer.writerows(rows)

print(f"\n Rezultatele au fost salvate în:\n{output_csv}")


In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, accuracy_score
import seaborn as sns
import numpy as np

true_labels = [label for label, _, _ in pairs]
pred_labels = [row[1] for row in rows]

cm = confusion_matrix(true_labels, pred_labels, labels=class_names)

# Per-class recall (diagonal / row sum). A class with no ground-truth samples gets 0
# instead of a divide-by-zero NaN, which would break the bar heights and labels below.
support = cm.sum(axis=1)
class_accuracy = np.divide(
    cm.diagonal(), support,
    out=np.zeros(len(class_names), dtype=float),
    where=support > 0,
)

fig, ax = plt.subplots(figsize=(8, 5))
bars = ax.bar(class_names, class_accuracy * 100, color='skyblue')
ax.set_ylim(0, 105)
ax.set_ylabel("Accuracy (%)")
ax.set_title("Accuracy pe clasă (dual-view ensemble)")

# Loop variable is `class_acc`, not `acc`, so the overall accuracy `acc` stays intact.
for bar, class_acc in zip(bars, class_accuracy):
    ax.text(bar.get_x() + bar.get_width() / 2, class_acc * 100 + 1, f"{class_acc * 100:.2f}%", ha='center')

ax.grid(axis='y', linestyle='--', alpha=0.5)
fig.tight_layout()
plt.show()

In [None]:
# Confusion matrix of the dual-view ensemble (rows = true class, columns = predicted).
fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=class_names,
    yticklabels=class_names,
    ax=ax,
)
ax.set_xlabel("Predicted")
ax.set_ylabel("True")
ax.set_title("Confusion matrix")
fig.tight_layout()
plt.show()

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models

class DualViewResNet(nn.Module):
    """Two-branch ResNet-50 classifier that fuses axial and sagittal features.

    Each view goes through its own ResNet-50 backbone (weights are not shared) with the
    final ``fc`` layer replaced by ``nn.Identity``. The two pooled feature vectors are
    concatenated and classified by a small MLP head.

    Args:
        num_classes: number of output logits.
        pretrained: load ImageNet (IMAGENET1K_V1) weights into both backbones.
    """

    def __init__(self, num_classes=5, pretrained=True):
        super().__init__()
        weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None

        self.resnet_axial = models.resnet50(weights=weights)
        self.resnet_sagittal = models.resnet50(weights=weights)

        # Read the pooled feature width from the backbones before their heads are
        # removed, instead of hard-coding 2 * 2048.
        fused_features = self.resnet_axial.fc.in_features + self.resnet_sagittal.fc.in_features

        self.resnet_axial.fc = nn.Identity()
        self.resnet_sagittal.fc = nn.Identity()

        self.classifier = nn.Sequential(
            nn.Linear(fused_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(512, num_classes)
        )

    def forward(self, axial_img, sagittal_img):
        """Return class logits for a batch of axial/sagittal image pairs.

        Both inputs must have the same batch size (presumably (B, 3, H, W) — confirm).
        """
        feat_axial = self.resnet_axial(axial_img)
        feat_sagittal = self.resnet_sagittal(sagittal_img)

        combined = torch.cat((feat_axial, feat_sagittal), dim=1)
        return self.classifier(combined)


In [None]:
from torch.utils.data import Dataset
from PIL import Image
import os

class DualViewDataset(Dataset):
    """Map-style dataset yielding ``(axial, sagittal, label_idx)`` triples.

    Args:
        pairs: sequence of ``(label, axial_path, sagittal_path)`` tuples, e.g. from
            ``find_matching_pairs``.
        transform: optional callable applied independently to each view's RGB PIL image.
        class_to_idx: optional label -> index mapping. When omitted, the labels present in
            ``pairs`` are sorted and numbered from 0. Pass an explicit mapping when building
            several splits so their indices are guaranteed to agree.
    """

    def __init__(self, pairs, transform=None, class_to_idx=None):
        self.pairs = pairs
        self.transform = transform
        self.class_to_idx = class_to_idx or {
            name: idx for idx, name in enumerate(sorted({label for label, _, _ in pairs}))
        }

    def __len__(self):
        return len(self.pairs)

    @staticmethod
    def _load_rgb(path):
        """Open an image, convert it to RGB and close the underlying file handle."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        label, axial_path, sagittal_path = self.pairs[idx]

        # Explicit checks instead of `assert`, which is stripped when Python runs with -O.
        if not os.path.exists(axial_path):
            raise FileNotFoundError(f"Axial image not found: {axial_path}")
        if not os.path.exists(sagittal_path):
            raise FileNotFoundError(f"Sagittal image not found: {sagittal_path}")

        axial_img = self._load_rgb(axial_path)
        sagittal_img = self._load_rgb(sagittal_path)

        if self.transform:
            axial_img = self.transform(axial_img)
            sagittal_img = self.transform(sagittal_img)

        return axial_img, sagittal_img, self.class_to_idx[label]


In [None]:
from torchvision import transforms
from torch.utils.data import DataLoader

# Preprocessing for the end-to-end DualViewResNet (224x224 input, same normalisation stats).
# NOTE: this rebinds the global `transform` that the 256x256 single-view ensemble cells above
# also use; make sure those cells do not pick up this 224x224 version if they are re-run.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.2192]*3, std=[0.2474]*3)
])

axial_dir = "D:/Licenta/Datasets/ADNI_Oficial/Processed/Axial/Train/"
sagittal_dir = "D:/Licenta/Datasets/ADNI_Oficial/Filtered/Sagittal/Train/"
pairs = find_matching_pairs(axial_dir, sagittal_dir)

# Explicit check instead of `assert`, so it still runs when Python is started with -O.
if not pairs:
    raise ValueError("No matching image pairs found!")

# Pinned host memory only speeds up host->GPU copies; without CUDA the DataLoader warns.
use_pin_memory = torch.cuda.is_available()

dual_dataset = DualViewDataset(pairs, transform=transform)
dual_loader = DataLoader(dual_dataset, batch_size=64, shuffle=True, num_workers=0, pin_memory=use_pin_memory)


In [None]:
# NOTE(review): the training cell below builds a fresh DualViewResNet, so this instance is discarded.
model = DualViewResNet(pretrained=True, num_classes=5).to(device)

In [None]:
from torch.cuda.amp import GradScaler, autocast
import torch.nn.functional as F
from sklearn.metrics import accuracy_score
import copy

def train_dualview_model(model, train_loader, val_loader, num_epochs=10, lr=1e-4, patience=7, save_path="best_dualview.pth"):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2, eta_min=1e-6)
    scaler = GradScaler() 
    criterion = torch.nn.CrossEntropyLoss()

    best_loss = float("inf")
    best_model = None
    no_improve_epochs = 0
    train_losses, val_losses, train_accs, val_accs = [], [], [], []

    for epoch in range(num_epochs):
        model.train()
        running_loss, preds_all, labels_all = 0.0, [], []

        for axial_img, sagittal_img, labels in train_loader:
            axial_img, sagittal_img, labels = axial_img.to(device), sagittal_img.to(device), labels.to(device)

            optimizer.zero_grad()
            with autocast():
                outputs = model(axial_img, sagittal_img)
                loss = criterion(outputs, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()
            preds_all += torch.argmax(outputs, dim=1).cpu().tolist()
            labels_all += labels.cpu().tolist()

        train_loss = running_loss / len(train_loader)
        train_acc = accuracy_score(labels_all, preds_all)
        train_losses.append(train_loss)
        train_accs.append(train_acc)

        model.eval()
        val_running_loss, val_preds, val_labels = 0.0, [], []
        with torch.no_grad():
            for axial_img, sagittal_img, labels in val_loader:
                axial_img, sagittal_img, labels = axial_img.to(device), sagittal_img.to(device), labels.to(device)
                with autocast():
                    outputs = model(axial_img, sagittal_img)
                    loss = criterion(outputs, labels)

                val_running_loss += loss.item()
                val_preds += torch.argmax(outputs, dim=1).cpu().tolist()
                val_labels += labels.cpu().tolist()

        val_loss = val_running_loss / len(val_loader)
        val_acc = accuracy_score(val_labels, val_preds)
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        scheduler.step(epoch + 1)

        print(f"Epoch {epoch+1}/{num_epochs} — Train Loss: {train_loss:.4f}, Acc: {train_acc*100:.2f}% | Val Loss: {val_loss:.4f}, Acc: {val_acc*100:.2f}%")

        if val_loss < best_loss:
            best_loss = val_loss
            best_model = copy.deepcopy(model.state_dict())
            torch.save(best_model, save_path)
            print(f"  Saved new best model at epoch {epoch+1}")
            no_improve_epochs = 0
        else:
            no_improve_epochs += 1
            if no_improve_epochs >= patience:
                print("Early stopping triggered.")
                break

    print(f"\nTraining finished. Best model saved to: {save_path}")
    return train_losses, val_losses, train_accs, val_accs


In [None]:
from sklearn.model_selection import train_test_split
# NOTE(review): `pairs` must hold the *training* pairs built in the DataLoader cell above.
# The ensemble-evaluation cell rebinds the same name to the Test pairs, so re-running that
# cell before this one would leak test images into training.
train_pairs, val_pairs = train_test_split(
    pairs, test_size=0.2, stratify=[label for label, _, _ in pairs], random_state=42
)

# One label -> index mapping shared by both splits, so their indices cannot diverge even if
# a rare class were missing from one of them.
split_class_to_idx = {name: idx for idx, name in enumerate(sorted({label for label, _, _ in pairs}))}

train_set = DualViewDataset(train_pairs, transform, class_to_idx=split_class_to_idx)
val_set = DualViewDataset(val_pairs, transform, class_to_idx=split_class_to_idx)

# Pinned host memory only speeds up host->GPU copies; without CUDA the DataLoader warns.
use_pin_memory = torch.cuda.is_available()
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=0, pin_memory=use_pin_memory)
val_loader = DataLoader(val_set, batch_size=64, shuffle=False, num_workers=0, pin_memory=use_pin_memory)

model = DualViewResNet(num_classes=5, pretrained=True)

train_losses, val_losses, train_accs, val_accs = train_dualview_model(
    model,
    train_loader,
    val_loader,
    num_epochs=20,
    lr=5e-5,
    patience=10,
    save_path="best_dualview_model.pth"
)


In [None]:
import matplotlib.pyplot as plt

def plot_training_history(train_losses, val_losses, train_accs, val_accs):
    """Show per-epoch loss (left) and accuracy (right) curves for train and validation.

    Accuracies are given as fractions and plotted as percentages.
    """
    epoch_axis = range(1, len(train_losses) + 1)
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(14, 6), dpi=100)

    loss_ax.plot(epoch_axis, train_losses, label='Train Loss', color='red')
    loss_ax.plot(epoch_axis, val_losses, label='Validation Loss', color='orange')
    loss_ax.set_xlabel("Epoch")
    loss_ax.set_ylabel("Loss")
    loss_ax.set_title("Loss per Epoch")
    loss_ax.legend()
    loss_ax.grid(True)

    train_pct = [fraction * 100 for fraction in train_accs]
    val_pct = [fraction * 100 for fraction in val_accs]
    acc_ax.plot(epoch_axis, train_pct, label='Train Accuracy', color='green')
    acc_ax.plot(epoch_axis, val_pct, label='Validation Accuracy', color='blue')
    acc_ax.set_xlabel("Epoch")
    acc_ax.set_ylabel("Accuracy (%)")
    acc_ax.set_title("Accuracy per Epoch")
    acc_ax.legend()
    acc_ax.grid(True)

    fig.suptitle("DualView Model Training Progress", fontsize=16)
    # Leave the top 4% of the figure for the suptitle.
    fig.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()

In [None]:
plot_training_history(
    train_losses=train_losses,
    val_losses=val_losses,
    train_accs=train_accs,
    val_accs=val_accs,
)

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, accuracy_score
import matplotlib.pyplot as plt
import torch

def evaluate_and_plot_confusion_matrix(model, test_loader, class_names, checkpoint_path="best_dualview_model.pth"):
    """Load a checkpoint into ``model``, evaluate it on ``test_loader`` and plot the confusion matrix.

    Args:
        model: two-input classifier called as ``model(axial, sagittal)``.
        test_loader: iterable of ``(axial, sagittal, labels)`` batches with integer labels
            indexing ``class_names``.
        class_names: display names, in label-index order.
        checkpoint_path: state-dict file to load before evaluating.

    Returns:
        The ``len(class_names)`` x ``len(class_names)`` confusion matrix (rows = true, columns = predicted).

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # weights_only=True: the checkpoint is a plain state dict, so refuse arbitrary pickled objects.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))
    model.to(device)
    model.eval()

    all_preds, all_labels = [], []
    with torch.no_grad():
        for axial_img, sagittal_img, labels in test_loader:
            outputs = model(axial_img.to(device), sagittal_img.to(device))
            all_preds += outputs.argmax(dim=1).cpu().tolist()
            all_labels += labels.tolist()

    if not all_labels:
        raise ValueError("test_loader yielded no samples")

    # Fix the label set so the matrix is always square in len(class_names), even when a class
    # is absent from this loader; otherwise ConfusionMatrixDisplay rejects display_labels.
    cm = confusion_matrix(all_labels, all_preds, labels=list(range(len(class_names))))
    acc = accuracy_score(all_labels, all_preds)
    print(f"Test Accuracy: {acc*100:.2f}%")

    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
    disp.plot(cmap=plt.cm.Blues, values_format='d')
    plt.title("Confusion Matrix")
    plt.show()

    return cm


In [None]:
class_names = ["AD", "CN", "EMCI", "LMCI", "MCI"]

# Evaluate on the held-out Test folders rather than `val_loader`. The validation split already
# selected the checkpoint (early stopping on validation loss), so its accuracy would be an
# optimistic "test" number. The label mapping follows `class_names` (alphabetical, matching
# the sorted default mapping used when training).
# NOTE(review): relies on `transform` currently being the 224x224 DualViewResNet pipeline.
test_pairs = find_matching_pairs(
    "D:/Licenta/Datasets/ADNI_Oficial/Processed/Axial/Test/",
    "D:/Licenta/Datasets/ADNI_Oficial/Filtered/Sagittal/Test/",
)
test_set = DualViewDataset(
    test_pairs, transform, class_to_idx={name: idx for idx, name in enumerate(class_names)}
)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False, num_workers=0)

model = DualViewResNet(num_classes=5, pretrained=False)
evaluate_and_plot_confusion_matrix(model, test_loader, class_names)


In [None]:
import random
import matplotlib.pyplot as plt
import torch
import numpy as np

class_names = ['AD', 'CN', 'EMCI', 'LMCI', 'MCI']


def _denormalize_for_display(tensor_img, mean, std):
    """Undo per-channel normalisation of a (C, H, W) tensor for ``imshow``.

    Returns an (H, W, C) NumPy array clipped to [0, 1].
    """
    img = tensor_img.cpu().numpy().transpose((1, 2, 0))
    img = img * np.asarray(std, dtype=float) + np.asarray(mean, dtype=float)
    return np.clip(img, 0, 1)


def show_random_dual_predictions(model, test_dataset, num_images=12, cols=4, device=None,
                                 mean=(0.2192, 0.2192, 0.2192), std=(0.2474, 0.2474, 0.2474)):
    """Predict a random sample of pairs from ``test_dataset`` and plot them, mistakes first.

    Args:
        model: two-input classifier called as ``model(axial, sagittal)``.
        test_dataset: dataset returning ``(axial_tensor, sagittal_tensor, label_idx)``.
        num_images: number of pairs to sample (capped at the dataset size).
        cols: unused; each figure row always shows one pair (axial | sagittal). Kept for
            backward compatibility.
        device: torch device; defaults to CUDA when available, else CPU.
        mean, std: normalisation statistics of the dataset's transform, used to undo it
            for display. The defaults match the notebook's Normalize step.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.eval()
    model.to(device)

    total_images = len(test_dataset)
    num_images = min(num_images, total_images)
    if num_images <= 0:
        # plt.subplots(0, 2) would raise; report instead of crashing.
        print("No image pairs to show.")
        return
    indices = random.sample(range(total_images), num_images)

    correct_list, wrong_list = [], []
    for idx in indices:
        axial_img, sagittal_img, label = test_dataset[idx]
        with torch.no_grad(), torch.amp.autocast(device_type=device.type, enabled=device.type == 'cuda'):
            outputs = model(axial_img.unsqueeze(0).to(device), sagittal_img.unsqueeze(0).to(device))
        pred_idx = outputs.argmax(dim=1).item()

        entry = {
            'axial': _denormalize_for_display(axial_img, mean, std),
            'sagittal': _denormalize_for_display(sagittal_img, mean, std),
            'pred': class_names[pred_idx],
            'true': class_names[label],
            'is_correct': pred_idx == label
        }
        (correct_list if entry['is_correct'] else wrong_list).append(entry)

    all_entries = wrong_list + correct_list
    total = len(all_entries)
    # squeeze=False keeps `axes` 2-D even for a single row, so no special case is needed.
    fig, axes = plt.subplots(total, 2, figsize=(8, total * 3), squeeze=False)

    for row_axes, entry in zip(axes, all_entries):
        colour = 'green' if entry['is_correct'] else 'red'
        caption = f"Pred: {entry['pred']} | True: {entry['true']}"
        for ax, image, view_title in ((row_axes[0], entry['axial'], "Axial"),
                                      (row_axes[1], entry['sagittal'], "Sagittal")):
            ax.imshow(image)
            ax.set_title(f"{view_title}\n{caption}", color=colour)
            ax.axis('off')

    plt.tight_layout()
    plt.show()

    print(f"\nSummary:")
    print(f"Total image pairs shown: {total}")
    print(f"Correct predictions: {len(correct_list)}")
    print(f"Wrong predictions:  {len(wrong_list)}")

    if wrong_list:
        print("\n Wrong predictions detail:")
        for idx, entry in enumerate(wrong_list, start=1):
            print(f"  {idx}. Predicted: {entry['pred']} | Actual: {entry['true']}")


In [None]:
show_random_dual_predictions(
    model=model,
    test_dataset=val_set,
    num_images=100,
)