In [1]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

# Load MNIST dataset
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

# Split the training dataset into training and validation sets
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)


In [2]:
def select_batch_and_samples(data_loader, batch_idx, sample_indices):
    batches = [batch for batch_idx_, batch in enumerate(data_loader) if batch_idx_ == batch_idx]
    if not batches:
        raise ValueError("Batch index out of range")
    batch = batches[0]
    samples = batch[0][sample_indices]
    return samples

def apply_patching(images, patch_size=4):
    batch_size, channels, height, width = images.shape
    patches = images.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
    patches = patches.contiguous().view(batch_size, channels, -1, patch_size, patch_size)
    patches = patches.permute(0, 2, 1, 3, 4)
    patches = patches.contiguous().view(batch_size * patches.size(1), channels, patch_size, patch_size)
    return patches

def apply_masking(patches, mask_percent=0.75):
    batch_size = patches.shape[0]
    mask = torch.rand(batch_size) < mask_percent
    masked_patches = patches.clone()
    masked_patches[mask] = 0
    return masked_patches, mask


In [2]:
import torch.nn as nn

class PatchEmbedding(nn.Module):
    def __init__(self, patch_size=4, embedding_dim=64):
        super(PatchEmbedding, self).__init__()
        self.patch_size = patch_size
        self.embedding_dim = embedding_dim
        self.patch_embedding = nn.Linear(patch_size * patch_size, embedding_dim)
        self.positional_embedding = nn.Parameter(torch.randn(28 // patch_size * 28 // patch_size, embedding_dim))

    def forward(self, patches):
        batch_size, _, h, w = patches.shape
        patches = patches.view(batch_size, -1, h * w)
        embedded_patches = self.patch_embedding(patches)
        embedded_patches += self.positional_embedding.unsqueeze(0)
        return embedded_patches

class MaskedAutoencoder(nn.Module):
    def __init__(self, embedding_dim=64):
        super(MaskedAutoencoder, self).__init__()
        self.patch_embedding = PatchEmbedding(embedding_dim=embedding_dim)
        self.encoder = nn.Sequential(
            nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=4),
            nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=4)
        )
        self.decoder = nn.Sequential(
            nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=4),
            nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=4)
        )
        self.reconstruction = nn.Linear(embedding_dim, 4 * 4)

    def forward(self, patches, mask):
        embedded_patches = self.patch_embedding(patches)
        encoded_patches = self.encoder(embedded_patches)
        decoded_patches = self.decoder(encoded_patches)
        reconstructed_patches = self.reconstruction(decoded_patches)
        return reconstructed_patches


In [3]:
import matplotlib.pyplot as plt
import torch
import torchvision
torch.__version__
from torch import nn
from torchvision import transforms

In [4]:
train_dir = '/content/MICCAI-Educational-Challenge-2024/train'
test_dir = '/content/MICCAI-Educational-Challenge-2024/test'
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
NUM_WORKERS = os.cpu_count()
def create_dataloaders(
 train_dir: str,
 test_dir: str,
 transform: transforms.Compose,
 batch_size: int,
 num_workers: int=NUM_WORKERS
 ):
 # Use ImageFolder to create dataset(s)
 train_data = datasets.ImageFolder(train_dir, transform=transform)
 test_data = datasets.ImageFolder(test_dir, transform=transform)
 # Get class names
 class_names = train_data.classes
 # Turn images into train and test data loaders
 train_dataloader = DataLoader(
 train_data,
 batch_size=batch_size,
 shuffle=True,
 num_workers=num_workers,
 pin_memory=True,
 )
 test_dataloader = DataLoader(
 test_data,
 batch_size=batch_size,
 shuffle=False,
 num_workers=num_workers,
 pin_memory=True,
 )
 return train_dataloader, test_dataloader, class_names

In [5]:
IMG_SIZE = 240 

manual_tranforms = transforms.Compose([
    transforms.Resize((IMG_SIZE , IMG_SIZE)), 
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),

])
print(f"Manually create Tranform:{manual_tranforms}")

Manually create Tranform:Compose(
    Resize(size=(240, 240), interpolation=bilinear, max_size=None, antialias=None)
    ToTensor()
    RandomHorizontalFlip(p=0.5)
    RandomVerticalFlip(p=0.5)
)


In [7]:
import torch
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np

# Define a masked transformation function
def masked_transform(image, mask, angle):
    # Convert image and mask to numpy arrays
    image_np = np.array(image)
    mask_np = np.array(mask)

    # Apply rotation to the entire image
    rotation_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.functional.rotate(angle),
        transforms.ToTensor()
    ])

    # Create a transformed version of the image
    transformed_image = rotation_transform(image)

    # Convert the transformed image back to a numpy array
    transformed_image_np = transformed_image.permute(1, 2, 0).numpy()

    # Apply the mask to combine the original and transformed images
    masked_image_np = image_np * (1 - mask_np) + transformed_image_np * mask_np

    return torch.from_numpy(masked_image_np).permute(2, 0, 1)

# Load MNIST dataset
mnist_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())

# Select an image from the dataset
image, label = mnist_dataset[0]

# Create a simple mask (e.g., mask the upper half of the image)
mask = torch.zeros_like(image)
mask[0, :14, :] = 1  # Mask the upper half

# Apply masked transformation
angle = 30  # Rotation angle
transformed_image = masked_transform(image, mask, angle)

# Plot the original image, mask, and transformed image
fig, axs = plt.subplots(1, 3, figsize=(15, 5))
axs[0].imshow(image.squeeze(), cmap='gray')
axs[0].set_title('Original Image')
axs[1].imshow(mask.squeeze(), cmap='gray')
axs[1].set_title('Mask')
axs[2].imshow(transformed_image.squeeze(), cmap='gray')
axs[2].set_title('Transformed Image')
plt.show()


TypeError: rotate() missing 1 required positional argument: 'angle'