In [9]:
import torch
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
import nibabel as nib
from PIL import Image
import os

In [2]:
class FocalLoss(torch.nn.Module):
    """Binary focal loss computed from raw logits.

    The per-element binary cross-entropy is scaled by ``(1 - p_t) ** gamma``.
    This down-weights examples the model already classifies confidently.
    ``p_t = exp(-BCE)``, which for hard 0/1 targets is exactly the predicted
    probability of the true class.

    Args:
        gamma: Focusing parameter. ``0`` reduces this to plain BCE-with-logits.
        reduction: ``'mean'``, ``'sum'`` or ``'none'`` (per-element loss).
        alpha: Optional weight in ``[0, 1]`` for positive targets. Negative
            targets get ``1 - alpha``. ``None`` (default) applies no class
            weighting.

    Raises:
        ValueError: if ``reduction`` is not one of the supported values or
            ``alpha`` is outside ``[0, 1]``.
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, gamma=2.0, reduction='mean', alpha=None):
        super(FocalLoss, self).__init__()
        # Reject typos early. Previously any unknown string silently returned
        # the unreduced loss, which only failed later at .backward().
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        if alpha is not None and not 0.0 <= alpha <= 1.0:
            raise ValueError(f"alpha must be in [0, 1] or None, got {alpha!r}")
        self.gamma = gamma
        self.reduction = reduction
        self.alpha = alpha

    def forward(self, inputs, targets):
        """Return the focal loss of logits ``inputs`` against ``targets``.

        ``inputs`` and ``targets`` must be broadcast-compatible, as required by
        ``F.binary_cross_entropy_with_logits``.
        """
        bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
        pt = torch.exp(-bce)
        loss = (1 - pt) ** self.gamma * bce
        if self.alpha is not None:
            # Linear in the target, so soft labels interpolate the two weights.
            alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
            loss = alpha_t * loss
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss


In [3]:
class MRIPatchDataset(Dataset):
    """Dataset yielding ``K`` random square patches from each PNG in a directory.

    Args:
        image_dir: Directory scanned (non-recursively) for files ending in ``.png``.
        patch_size: Side length, in pixels, of each square patch.
        transform: Optional callable applied to each PIL patch.
        K: Number of patches sampled per image on every ``__getitem__`` call.

    Each item is ``(patches, idx)``: a list of ``K`` (optionally transformed)
    patches and the index of the source image. Patch locations are redrawn
    from ``np.random`` on every access.
    """

    def __init__(self, image_dir, patch_size=32, transform=None, K=2):
        self.image_dir = image_dir
        self.patch_size = patch_size
        self.transform = transform
        self.K = K
        # os.listdir order is arbitrary; sort so the idx -> file mapping is reproducible.
        self.image_files = sorted(
            os.path.join(image_dir, name)
            for name in os.listdir(image_dir)
            if name.endswith('.png')
        )

    def __len__(self):
        return len(self.image_files)

    def _random_offset(self, extent):
        """Return a uniformly random start offset so a patch fits within ``extent`` pixels."""
        if extent < self.patch_size:
            raise ValueError(
                f"extent {extent} is smaller than patch_size {self.patch_size}"
            )
        # randint's upper bound is exclusive. The +1 makes the flush
        # right/bottom position reachable and allows extent == patch_size.
        return np.random.randint(0, extent - self.patch_size + 1)

    def __getitem__(self, idx):
        path = self.image_files[idx]
        # Close the file handle promptly; convert() returns an in-memory copy.
        with Image.open(path) as src:
            img = src.convert('RGB')
        w, h = img.size
        if w < self.patch_size or h < self.patch_size:
            raise ValueError(
                f"{path}: image size {w}x{h} is smaller than patch_size {self.patch_size}"
            )
        patches = []
        for _ in range(self.K):
            left = self._random_offset(w)
            top = self._random_offset(h)
            patch = img.crop((left, top, left + self.patch_size, top + self.patch_size))
            if self.transform:
                patch = self.transform(patch)
            patches.append(patch)
        return patches, idx



In [4]:
def create_pairs(dataset):
    """Build labelled patch pairs for relational self-supervision.

    For each image, every index pair ``k < j`` of its patches yields two pairs:

    * a positive pair ``(patches[k], patches[j])`` labelled ``1.0``;
    * a negative pair ``(patches[k], p)`` labelled ``0.0``. Here ``p`` is a
      random patch from a random *other* image, chosen with ``np.random``.

    Args:
        dataset: Supports ``len()`` and indexing. ``dataset[i]`` returns
            ``(patches, idx)`` where ``idx`` identifies the source image. Each
            access may resample patches.

    Returns:
        ``(pairs, labels)``: parallel lists of 2-tuples and float labels.

    Raises:
        ValueError: if a negative is needed but the dataset has fewer than two
            images. Without this check the rejection loop below would never
            terminate.
    """
    n_images = len(dataset)
    pairs = []
    labels = []
    for i in range(n_images):
        patches, idx = dataset[i]
        for k in range(len(patches)):
            for j in range(k + 1, len(patches)):
                # Positive pair: two patches from the same image.
                pairs.append((patches[k], patches[j]))
                labels.append(1.0)
                # Negative pair: anchor patch vs. a patch from a different image.
                if n_images < 2:
                    raise ValueError(
                        "create_pairs needs at least two images to draw negatives"
                    )
                neg_idx = np.random.randint(0, n_images)
                while neg_idx == idx:
                    neg_idx = np.random.randint(0, n_images)
                neg_patches, _ = dataset[neg_idx]
                neg_patch = neg_patches[np.random.randint(0, len(neg_patches))]
                pairs.append((patches[k], neg_patch))
                labels.append(0.0)
    return pairs, labels


In [5]:
class Conv4(torch.nn.Module):
    """Four-stage convolutional backbone that maps an image batch to 64-d features.

    Each stage is a 3x3 conv (no bias) followed by BatchNorm, ReLU and pooling.
    Stages 1-3 halve H and W with 2x2 average pooling. Stage 4 global-average-pools
    to 1x1, so the output is ``(batch, 64)``. Inputs need H, W >= 8 to survive
    the three halvings.

    Args:
        in_channels: Number of input channels. Defaults to 3 (RGB); use 1 for
            single-channel images.
    """

    def __init__(self, in_channels=3):
        super(Conv4, self).__init__()
        # Attribute names and Sequential indices match the original layout, so
        # existing state_dicts (layerN.0.weight, layerN.1.running_mean, ...) load.
        self.layer1 = self._stage(in_channels, 8)
        self.layer2 = self._stage(8, 16)
        self.layer3 = self._stage(16, 32)
        self.layer4 = self._stage(32, 64, pool=torch.nn.AdaptiveAvgPool2d(1))

    @staticmethod
    def _stage(c_in, c_out, pool=None):
        """Build Conv3x3 -> BN -> ReLU -> pool; the default pool halves H and W."""
        if pool is None:
            pool = torch.nn.AvgPool2d(kernel_size=2, stride=2)
        return torch.nn.Sequential(
            torch.nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1, bias=False),
            torch.nn.BatchNorm2d(c_out),
            torch.nn.ReLU(),
            pool,
        )

    def forward(self, x):
        h = self.layer1(x)
        h = self.layer2(h)
        h = self.layer3(h)
        h = self.layer4(h)
        # (batch, 64, 1, 1) -> (batch, 64)
        return torch.flatten(h, 1)


In [6]:
class RelationalReasoning(torch.nn.Module):
    """Siamese relation model that scores whether two inputs are related.

    A shared ``backbone`` embeds each input. The two embeddings are
    concatenated along dim 1 and an MLP head maps them to a single logit
    per pair.

    Args:
        backbone: Module mapping a batch to ``(batch, feature_size)`` features.
        feature_size: Width of one backbone embedding (the head sees twice this).
    """

    def __init__(self, backbone, feature_size=64):
        super().__init__()
        self.backbone = backbone
        hidden_units = 256
        self.relation_head = torch.nn.Sequential(
            torch.nn.Linear(2 * feature_size, hidden_units),
            torch.nn.BatchNorm1d(hidden_units),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(hidden_units, 1),
        )

    def forward(self, x1, x2):
        # Same weights embed both sides; x1 is processed first, then x2.
        joint = torch.cat((self.backbone(x1), self.backbone(x2)), dim=1)
        return self.relation_head(joint)


In [7]:
def train_model(model, pairs, labels, epochs=10, batch_size=64, learning_rate=0.001,
                criterion=None):
    """Train a pairwise relation model with Adam and print per-epoch metrics.

    Args:
        model: Module called as ``model(img1_batch, img2_batch)`` that returns
            one logit per pair, shaped ``(batch, 1)`` or ``(batch,)``.
        pairs: Sequence of ``(tensor_a, tensor_b)`` tuples of equal shapes.
        labels: Sequence of float targets (1.0 related / 0.0 unrelated),
            parallel to ``pairs``.
        epochs: Number of passes over the data.
        batch_size: Mini-batch size.
        learning_rate: Adam learning rate.
        criterion: Loss called as ``criterion(logits, targets)``. Defaults to
            ``FocalLoss()``.

    Returns:
        List with one ``(mean_batch_loss, accuracy_percent)`` tuple per epoch.

    Raises:
        ValueError: if ``pairs`` and ``labels`` differ in length, or no
            training batch can be formed.
    """
    if len(pairs) != len(labels):
        raise ValueError(
            f"pairs and labels differ in length: {len(pairs)} vs {len(labels)}"
        )
    if criterion is None:
        criterion = FocalLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    dataset = list(zip(pairs, labels))
    # BatchNorm in training mode rejects a batch of a single sample, so drop a
    # trailing singleton batch instead of crashing on the last step.
    drop_last = len(dataset) % batch_size == 1
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last)
    if len(dataloader) == 0:
        raise ValueError("not enough pairs to form a training batch")
    # Keep inputs on the same device as the model's parameters.
    device = next(model.parameters()).device

    history = []
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        # default_collate turns each (a, b) tuple into a list [batch_a, batch_b]
        # and the float labels into a 1-D tensor.
        for (img1, img2), label in dataloader:
            img1 = img1.to(device)
            img2 = img2.to(device)
            label = label.to(device=device, dtype=torch.float32)

            optimizer.zero_grad()
            # reshape(-1) keeps the batch axis; squeeze() would make a batch of one 0-d.
            outputs = model(img1, img2).reshape(-1)
            loss = criterion(outputs, label)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            predicted = (torch.sigmoid(outputs.detach()) > 0.5).float()
            total += label.size(0)
            correct += (predicted == label).sum().item()

        epoch_loss = running_loss / len(dataloader)
        accuracy = 100 * correct / total
        history.append((epoch_loss, accuracy))
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss:.4f}, Accuracy: {accuracy:.2f}%")
    return history


In [8]:
# Data augmentation, applied independently to every patch the dataset samples.
# Keyword arguments spell out the same values as the positional/default form:
# - the random crop keeps 80-100% of the patch area and resizes back to 32x32;
# - the horizontal flip fires with probability 0.5;
# - the rotation angle is drawn from [-20, 20] degrees.
# ToTensor finally converts the PIL patch to a float CHW tensor in [0, 1].
transform = transforms.Compose([
    transforms.RandomResizedCrop(size=32, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=20),
    transforms.ToTensor(),
])

In [None]:
# Main script: build the patch dataset and sample labelled training pairs.
# Seed both RNGs used downstream so a Restart & Run All reproduces the same pairs:
# NumPy drives patch offsets and negative sampling; torch drives the
# torchvision random augmentations.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Point MRI_IMAGE_DIR at the image folder instead of editing this cell.
image_dir = os.environ.get('MRI_IMAGE_DIR', 'path_to_mri_images')
if not os.path.isdir(image_dir):
    raise FileNotFoundError(
        f"MRI image directory not found: {image_dir!r} (set MRI_IMAGE_DIR)"
    )
dataset = MRIPatchDataset(image_dir, transform=transform, K=2)
pairs, labels = create_pairs(dataset)