In [1]:
import os
import shutil
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms, models
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score
from PIL import Image

def get_patient_paths(directory):
    return [os.path.join(directory, patient_folder) for patient_folder in os.listdir(directory) if os.path.isdir(os.path.join(directory, patient_folder))]

def get_image_paths(patient_folder):
    return [os.path.join(patient_folder, img) for img in os.listdir(patient_folder) if img.endswith('.jpg')]

def copy_patients(src_patients, dst_dir):
    for patient_folder in src_patients:
        patient_name = os.path.basename(patient_folder)
        dst_patient_folder = os.path.join(dst_dir, patient_name)
        if not os.path.exists(dst_patient_folder):
            try:
                shutil.copytree(patient_folder, dst_patient_folder)
            except Exception as e:
                print(f"Error copying {patient_name}: {str(e)}")

benign_dir = "/kaggle/input/iaaa-mri-train-data-partition/iaaa-mri-train-data partition/data/benign"
malignant_dir = "/kaggle/input/iaaa-mri-train-data-partition/iaaa-mri-train-data partition/data/malignant"

benign_patient_folders = get_patient_paths(benign_dir)
malignant_patient_folders = get_patient_paths(malignant_dir)

benign_train, benign_test = train_test_split(benign_patient_folders, test_size=0.05, random_state=42)
malignant_train, malignant_test = train_test_split(malignant_patient_folders, test_size=0.05, random_state=42)

train_dir = "/kaggle/working/train"
test_dir = "/kaggle/working/test"

if not os.path.exists(train_dir):
    os.makedirs(train_dir)
    print(f"Created directory: {train_dir}")
else:
    print(f"Directory already exists: {train_dir}")

if not os.path.exists(test_dir):
    os.makedirs(test_dir)
    print(f"Created directory: {test_dir}")
else:
    print(f"Directory already exists: {test_dir}")

copy_patients(benign_train, os.path.join(train_dir, "benign"))
copy_patients(malignant_train, os.path.join(train_dir, "malignant"))
copy_patients(benign_test, os.path.join(test_dir, "benign"))
copy_patients(malignant_test, os.path.join(test_dir, "malignant"))

train_benign_patients = get_patient_paths(os.path.join(train_dir, "benign"))
train_malignant_patients = get_patient_paths(os.path.join(train_dir, "malignant"))
test_benign_patients = get_patient_paths(os.path.join(test_dir, "benign"))
test_malignant_patients = get_patient_paths(os.path.join(test_dir, "malignant"))

train_patients = train_benign_patients + train_malignant_patients
train_labels = [0] * len(train_benign_patients) + [1] * len(train_malignant_patients)
test_patients = test_benign_patients + test_malignant_patients
test_labels = [0] * len(test_benign_patients) + [1] * len(test_malignant_patients)

train_patients, val_patients, train_labels, val_labels = train_test_split(train_patients, train_labels, test_size=0.2, random_state=42, stratify=train_labels)

print(f"Train patients: {len(train_patients)}, Validation patients: {len(val_patients)}, Test patients: {len(test_patients)}")

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])


class BrainMRIDataset(Dataset):
    def __init__(self, patient_folders, labels, transform=None, max_images=20):
        self.patient_folders = patient_folders
        self.labels = labels
        self.transform = transform
        self.max_images = max_images

    def __len__(self):
        return len(self.patient_folders)

    def __getitem__(self, index):
        patient_folder = self.patient_folders[index]
        image_paths = get_image_paths(patient_folder)
        images = []
        
        for img_path in image_paths[:self.max_images]:
            image = Image.open(img_path).convert('RGB')
            if self.transform:
                image = self.transform(image)
            images.append(image)

        if len(images) < self.max_images:
            padding = [torch.zeros_like(images[0]) for _ in range(self.max_images - len(images))]
            images.extend(padding)

        images = torch.stack(images)
        label = self.labels[index]
        return images, label

batch_size = 8

class_sample_count = np.array([train_labels.count(0), train_labels.count(1)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in train_labels])
samples_weight = torch.from_numpy(samples_weight)
sampler = WeightedRandomSampler(samples_weight.type('torch.DoubleTensor'), len(samples_weight))

train_dataset = BrainMRIDataset(train_patients, train_labels, transform=transform)
val_dataset = BrainMRIDataset(val_patients, val_labels, transform=transform)
test_dataset = BrainMRIDataset(test_patients, test_labels, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        BCE_loss = nn.BCEWithLogitsLoss(reduction='none')(inputs, targets)
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss
        return F_loss.mean()

class AttentionMIL(nn.Module):
    def __init__(self, num_classes=1):
        super(AttentionMIL, self).__init__()
        self.L = 512
        self.D = 128
        self.K = 1

        self.feature_extractor = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)
        self.feature_extractor.fc = nn.Identity()

        self.attention = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Tanh(),
            nn.Linear(self.D, self.K)
        )

        self.classifier = nn.Sequential(
            nn.Linear(self.L*self.K, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        batch_size, num_images, channels, height, width = x.size()
        x = x.view(-1, channels, height, width)
        H = self.feature_extractor(x)
        H = H.view(batch_size, num_images, -1)
        A = self.attention(H)
        A = torch.transpose(A, 2, 1)
        A = nn.functional.softmax(A, dim=2) 
        M = torch.bmm(A, H)
        Y_logit = self.classifier(M.view(batch_size, -1))
        Y_hat = torch.sigmoid(Y_logit)
        
        return Y_logit, Y_hat, A

    def calculate_classification_error(self, X, Y):
        Y = Y.float()
        _, Y_hat, _ = self.forward(X)
        Y_hat = (Y_hat > 0.5).float()
        error = 1. - Y_hat.eq(Y).cpu().float().mean().item()
        return error, Y_hat

model = AttentionMIL()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

criterion = FocalLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-5)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=5)

def train(epoch):
    model.train()
    train_loss = 0.
    train_correct = 0
    train_total = 0
    y_true = []
    y_pred = []

    for batch_idx, (data, label) in enumerate(train_loader):
        data, label = data.to(device), label.to(device)
        label = label.float().unsqueeze(1)

        optimizer.zero_grad()
        Y_logit, Y_hat, _ = model(data)
        loss = criterion(Y_logit, label)
        train_loss += loss.item()
        
        Y_hat = (Y_hat > 0.5).float()
        train_correct += (Y_hat == label).sum().item()
        train_total += label.size(0)

        y_true.extend(label.cpu().numpy())
        y_pred.extend(Y_hat.cpu().numpy())

        loss.backward()
        optimizer.step()

    train_accuracy = train_correct / train_total
    train_precision = precision_score(y_true, y_pred)
    train_recall = recall_score(y_true, y_pred)
    train_f1 = f1_score(y_true, y_pred)
    train_auc = roc_auc_score(y_true, y_pred)
    
    print(f'Epoch: {epoch}')
    print(f'Train Loss: {train_loss / len(train_loader):.4f}')
    print(f'Train Accuracy: {train_accuracy:.4f}')
    print(f'Train Precision: {train_precision:.4f}')
    print(f'Train Recall: {train_recall:.4f}')
    print(f'Train F1-score: {train_f1:.4f}')
    print(f'Train AUC: {train_auc:.4f}')
    
    return train_loss / len(train_loader), train_accuracy

def validate():
    model.eval()
    val_loss = 0.
    val_correct = 0
    val_total = 0
    y_true = []
    y_pred = []

    with torch.no_grad():
        for data, label in val_loader:
            data, label = data.to(device), label.to(device)
            label = label.float().unsqueeze(1)

            Y_logit, Y_hat, _ = model(data)
            loss = criterion(Y_logit, label)
            val_loss += loss.item()

            Y_hat = (Y_hat > 0.5).float()
            val_correct += (Y_hat == label).sum().item()
            val_total += label.size(0)

            y_true.extend(label.cpu().numpy())
            y_pred.extend(Y_hat.cpu().numpy())

    val_accuracy = val_correct / val_total
    val_precision = precision_score(y_true, y_pred)
    val_recall = recall_score(y_true, y_pred)
    val_f1 = f1_score(y_true, y_pred)
    val_auc = roc_auc_score(y_true, y_pred)
    
    print(f'Validation Loss: {val_loss / len(val_loader):.4f}')
    print(f'Validation Accuracy: {val_accuracy:.4f}')
    print(f'Validation Precision: {val_precision:.4f}')
    print(f'Validation Recall: {val_recall:.4f}')
    print(f'Validation F1-score: {val_f1:.4f}')
    print(f'Validation AUC: {val_auc:.4f}')
    
    return val_loss / len(val_loader), val_accuracy, val_precision, val_recall, val_f1, val_auc

def evaluate_per_patient(model, data_loader):
    model.eval()
    patient_predictions = {}
    patient_true_labels = {}

    with torch.no_grad():
        for data, labels in data_loader:
            data, labels = data.to(device), labels.to(device)
            _, Y_hat, _ = model(data)
            Y_hat = (Y_hat > 0.5).float()

            for i, label in enumerate(labels):
                patient_id = f"Patient_{i}"
                patient_predictions[patient_id] = Y_hat[i].item()
                patient_true_labels[patient_id] = label.item()

    correct_predictions = sum(1 for p_id in patient_predictions if patient_predictions[p_id] == patient_true_labels[p_id])
    accuracy = correct_predictions / len(patient_predictions)

    print(f"Per-patient accuracy: {accuracy:.4f}")
    print("Sample of patient predictions:")
    for i, (p_id, pred) in enumerate(list(patient_predictions.items())[:5]):
        print(f"{p_id}: Predicted: {pred}, True: {patient_true_labels[p_id]}")

print("Evaluating per patient on validation set:")
evaluate_per_patient(model, val_loader)

def test():
    model.eval()
    test_loss = 0.
    test_correct = 0
    test_total = 0
    y_true = []
    y_pred = []

    with torch.no_grad():
        for data, label in test_loader:
            data, label = data.to(device), label.to(device)
            label = label.float().unsqueeze(1)

            Y_logit, Y_hat, _ = model(data)
            loss = criterion(Y_logit, label)
            test_loss += loss.item()

            Y_hat = (Y_hat > 0.5).float()
            test_correct += (Y_hat == label).sum().item()
            test_total += label.size(0)

            y_true.extend(label.cpu().numpy())
            y_pred.extend(Y_hat.cpu().numpy())

    test_accuracy = test_correct / test_total
    test_precision = precision_score(y_true, y_pred)
    test_recall = recall_score(y_true, y_pred)
    test_f1 = f1_score(y_true, y_pred)
    test_auc = roc_auc_score(y_true, y_pred)
    
    print(f'Test Loss: {test_loss / len(test_loader):.4f}')
    print(f'Test Accuracy: {test_accuracy:.4f}')
    print(f'Test Precision: {test_precision:.4f}')
    print(f'Test Recall: {test_recall:.4f}')
    print(f'Test F1-score: {test_f1:.4f}')
    print(f'Test AUC: {test_auc:.4f}')

epochs = 40
best_val_metrics = {
    'accuracy': 0,
    'precision': 0,
    'recall': 0,
    'f1': 0,
    'auc': 0
}

for epoch in range(1, epochs + 1):
    train_loss, train_accuracy = train(epoch)
    val_loss, val_accuracy, val_precision, val_recall, val_f1, val_auc = validate()
    scheduler.step(val_accuracy)
    current_lr = scheduler.optimizer.param_groups[0]['lr']
    print(f'Current learning rate: {current_lr}')

    if (val_accuracy > 0.90 and
        val_precision > 0.70 and
        val_recall > 0.70 and
        val_f1 > 0.75 and
        val_auc > 0.80):
        
        if (val_accuracy > best_val_metrics['accuracy'] or
            val_precision > best_val_metrics['precision'] or
            val_recall > best_val_metrics['recall'] or
            val_f1 > best_val_metrics['f1'] or
            val_auc > best_val_metrics['auc']):
            
            best_val_metrics['accuracy'] = val_accuracy
            best_val_metrics['precision'] = val_precision
            best_val_metrics['recall'] = val_recall
            best_val_metrics['f1'] = val_f1
            best_val_metrics['auc'] = val_auc

            torch.save(model.state_dict(), 'best_model.pth')
            print(f'Model saved at epoch {epoch}')
            print(f'Best validation metrics: {best_val_metrics}')

print("Training completed.")

if os.path.exists('best_model.pth'):
    model.load_state_dict(torch.load('best_model.pth'))
    print("Best model loaded. Evaluating on test set:")
    test()
else:
    print("No model was saved that met the specified criteria.")

Created directory: /kaggle/working/train
Created directory: /kaggle/working/test
Train patients: 2379, Validation patients: 595, Test patients: 158


Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth
100%|██████████| 83.3M/83.3M [00:00<00:00, 185MB/s]


Evaluating per patient on validation set:
Per-patient accuracy: 0.8750
Sample of patient predictions:
Patient_0: Predicted: 0.0, True: 0
Patient_1: Predicted: 0.0, True: 0
Patient_2: Predicted: 0.0, True: 0
Patient_3: Predicted: 0.0, True: 0
Patient_4: Predicted: 0.0, True: 0
Epoch: 1
Train Loss: 0.0391
Train Accuracy: 0.6738
Train Precision: 0.6891
Train Recall: 0.6846
Train F1-score: 0.6868
Train AUC: 0.6733
Validation Loss: 0.0248
Validation Accuracy: 0.8420
Validation Precision: 0.3980
Validation Recall: 0.5270
Validation F1-score: 0.4535
Validation AUC: 0.7069
Current learning rate: 0.0001
Epoch: 2
Train Loss: 0.0318
Train Accuracy: 0.7684
Train Precision: 0.7630
Train Recall: 0.7544
Train F1-score: 0.7587
Train AUC: 0.7679
Validation Loss: 0.0236
Validation Accuracy: 0.8437
Validation Precision: 0.3908
Validation Recall: 0.4595
Validation F1-score: 0.4224
Validation AUC: 0.6789
Current learning rate: 0.0001
Epoch: 3
Train Loss: 0.0226
Train Accuracy: 0.8609
Train Precision: 0.860

  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 8
Train Loss: 0.0074
Train Accuracy: 0.9639
Train Precision: 0.9611
Train Recall: 0.9690
Train F1-score: 0.9650
Train AUC: 0.9637
Validation Loss: 0.0446
Validation Accuracy: 0.9042
Validation Precision: 0.9474
Validation Recall: 0.2432
Validation F1-score: 0.3871
Validation AUC: 0.6207
Current learning rate: 0.0001
Epoch: 9
Train Loss: 0.0078
Train Accuracy: 0.9617
Train Precision: 0.9559
Train Recall: 0.9666
Train F1-score: 0.9612
Train AUC: 0.9618
Validation Loss: 0.0252
Validation Accuracy: 0.9076
Validation Precision: 0.9524
Validation Recall: 0.2703
Validation F1-score: 0.4211
Validation AUC: 0.6342
Current learning rate: 0.0001
Epoch: 10
Train Loss: 0.0057
Train Accuracy: 0.9710
Train Precision: 0.9726
Train Recall: 0.9684
Train F1-score: 0.9705
Train AUC: 0.9710
Validation Loss: 0.0167
Validation Accuracy: 0.9378
Validation Precision: 0.8776
Validation Recall: 0.5811
Validation F1-score: 0.6992
Validation AUC: 0.7848
Current learning rate: 0.0001
Epoch: 11
Train Loss: 0.

  model.load_state_dict(torch.load('best_model.pth'))


Test Loss: 0.0337
Test Accuracy: 0.9430
Test Precision: 0.8667
Test Recall: 0.6500
Test F1-score: 0.7429
Test AUC: 0.8178


In [2]:
model.load_state_dict(torch.load('best_model_weights_only.pth', weights_only=True))
print("Best model loaded. Evaluating on test set:")
test()

FileNotFoundError: [Errno 2] No such file or directory: 'best_model_weights_only.pth'

In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from sklearn.metrics import precision_score, recall_score, roc_auc_score
from PIL import Image
import pydicom

MODEL_PATH = '/kaggle/working/best_model.pth'
DATA_DIR = '/kaggle/input/iaaa-mri-challenge/data'
CSV_PATH = '/kaggle/input/iaaa-mri-challenge/train.csv'

class AttentionMIL(nn.Module):
    def __init__(self, num_classes=1):
        super(AttentionMIL, self).__init__()
        self.L = 512
        self.D = 128
        self.K = 1

        self.feature_extractor = models.resnet34(weights=None)
        self.feature_extractor.fc = nn.Identity()

        self.attention = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Tanh(),
            nn.Linear(self.D, self.K)
        )

        self.classifier = nn.Sequential(
            nn.Linear(self.L*self.K, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        batch_size, num_images, channels, height, width = x.size()
        x = x.view(-1, channels, height, width)
        H = self.feature_extractor(x)
        H = H.view(batch_size, num_images, -1)
        A = self.attention(H)
        A = torch.transpose(A, 2, 1)
        A = nn.functional.softmax(A, dim=2)
        M = torch.bmm(A, H)
        Y_logit = self.classifier(M.view(batch_size, -1))
        Y_hat = torch.sigmoid(Y_logit)
        return Y_logit, Y_hat, A

def dicom_to_array(dicom_path):
    dicom = pydicom.dcmread(dicom_path)
    return dicom.pixel_array

class BrainMRIDataset(Dataset):
    def __init__(self, csv_file, root_dir, transform=None):
        self.data_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.data_frame)

    def __getitem__(self, idx):
        patient_id = self.data_frame.iloc[idx, 0]
        label = self.data_frame.iloc[idx, 1]
        patient_dir = os.path.join(self.root_dir, patient_id)
        
        images = []
        for img_name in os.listdir(patient_dir):
            img_path = os.path.join(patient_dir, img_name)
            image = dicom_to_array(img_path)
            image = Image.fromarray(image).convert('RGB')
            if self.transform:
                image = self.transform(image)
            images.append(image)
        
        if len(images) < 20:
            padding = [torch.zeros_like(images[0]) for _ in range(20 - len(images))]
            images.extend(padding)
        elif len(images) > 20:
            images = images[:20]
        
        images = torch.stack(images)
        return images, torch.tensor(label, dtype=torch.float32)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AttentionMIL().to(device)
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.eval()

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

dataset = BrainMRIDataset(csv_file=CSV_PATH, root_dir=DATA_DIR, transform=transform)
dataloader = DataLoader(dataset, batch_size=8, shuffle=False)

y_true = []
y_pred = []

with torch.no_grad():
    for images, labels in dataloader:
        images = images.to(device)
        labels = labels.to(device)
        
        _, predictions, _ = model(images)
        predictions = predictions.squeeze().cpu().numpy()
        labels = labels.cpu().numpy()
        
        y_true.extend(labels)
        y_pred.extend(predictions)

y_true = np.array(y_true)
y_pred = np.array(y_pred)

precision = precision_score(y_true, y_pred.round())
recall = recall_score(y_true, y_pred.round())
auc = roc_auc_score(y_true, y_pred)

print(f'Precision: {precision:.4f}')
print(f'Recall: {recall:.4f}')
print(f'AUC: {auc:.4f}')