Imports

In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torchvision.transforms import ToTensor, Resize
from tqdm import tqdm, trange

from torch.utils.data import Dataset ,DataLoader, WeightedRandomSampler
from torch.utils.data.dataset import Subset
import os
from PIL import Image

from torchvision.transforms import ToTensor

import torch.nn as nn
import torch
class CreatePatches(nn.Module):
    """Turn an image into a sequence of patch embeddings.

    A Conv2d whose kernel size and stride both equal ``patch_size`` projects
    every non-overlapping ``patch_size`` x ``patch_size`` tile of the input to
    a single ``embed_dim``-dimensional vector.
    """

    def __init__(self, channels=1, embed_dim=768, patch_size=16):
        super().__init__()
        self.patch = nn.Conv2d(
            channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Embed ``x`` patch-wise and return the patches as a token sequence.

        For a batched ``(B, C, H, W)`` input the result has shape
        ``(B, (H // patch_size) * (W // patch_size), embed_dim)``.
        """
        grid = self.patch(x)                 # (B, embed_dim, H // p, W // p)
        tokens = grid.flatten(start_dim=2)   # (B, embed_dim, num_patches)
        return tokens.permute(0, 2, 1)       # (B, num_patches, embed_dim)

class AttentionBlock(nn.Module):
    """Pre-norm Transformer encoder block.

    LayerNorm -> multi-head self-attention -> residual add, followed by
    LayerNorm -> MLP (Linear, GELU, Dropout, Linear, Dropout) -> residual add.
    Inputs are batch-first (``batch_first=True`` on the attention layer).
    """

    def __init__(self, embed_dim, hidden_dim, num_heads, dropout=0.0):
        super().__init__()
        # Normalisation applied before self-attention.
        self.pre_norm = nn.LayerNorm(embed_dim, eps=1e-06)
        self.attention = nn.MultiheadAttention(
            embed_dim=embed_dim, num_heads=num_heads, dropout=dropout, batch_first=True,
        )
        # Normalisation applied before the feed-forward network.
        self.norm = nn.LayerNorm(embed_dim, eps=1e-06)
        mlp_layers = [
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        ]
        self.MLP = nn.Sequential(*mlp_layers)

    def forward(self, x):
        """Apply the attention sub-layer then the MLP sub-layer, each with a skip connection."""
        normed = self.pre_norm(x)
        # MultiheadAttention returns (output, weights); only the output is used.
        attn_out, _attn_weights = self.attention(normed, normed, normed)
        residual = x + attn_out
        return residual + self.MLP(self.norm(residual))

class ViT(nn.Module):
    """Vision Transformer classifier.

    The image is split into patches by ``CreatePatches``. A learnable [CLS]
    token is prepended and a learnable positional embedding is added. The
    sequence then passes through ``num_layers`` ``AttentionBlock``s, and the
    final LayerNorm'd [CLS] token is mapped to ``num_classes`` logits.

    ``pos_embedding`` is sized for ``(img_size // patch_size) ** 2`` patches,
    so inputs are expected to be ``img_size`` x ``img_size``.
    """

    def __init__(
        self,
        img_size=224,
        in_channels=1,
        patch_size=16,
        embed_dim=768,
        hidden_dim=3072,
        num_heads=12,
        num_layers=12,
        dropout=0.0,
        num_classes=4
    ):
        super().__init__()
        self.patch_size = patch_size
        num_patches = (img_size // patch_size) ** 2
        self.patches = CreatePatches(
            channels=in_channels,
            embed_dim=embed_dim,
            patch_size=patch_size
        )

        # Learnable positional encoding: one vector per patch plus one for [CLS].
        self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        self.attn_layers = nn.ModuleList(
            [AttentionBlock(embed_dim, hidden_dim, num_heads, dropout) for _ in range(num_layers)]
        )
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(embed_dim, eps=1e-06)
        self.head = nn.Linear(embed_dim, num_classes)

        # Use the same small-std truncated normal as the Linear weights in
        # _init_weights. The previous torch.randn init had unit std, about 50x
        # larger than every other learned weight at initialisation.
        nn.init.trunc_normal_(self.pos_embedding, std=.02)
        nn.init.trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise Linear (trunc-normal weights, zero bias) and LayerNorm (identity) modules."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for an image batch ``x``."""
        tokens = self.patches(x)  # (B, num_patches, embed_dim)
        batch_size = tokens.shape[0]

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, tokens), dim=1)
        # Out-of-place add; pos_embedding broadcasts over the batch dimension.
        x = x + self.pos_embedding

        x = self.dropout(x)

        for layer in self.attn_layers:
            x = layer(x)
        x = self.ln(x)
        # Classify from the [CLS] token (position 0).
        return self.head(x[:, 0])


  from .autonotebook import tqdm as notebook_tqdm


Define the dataset

In [2]:
class AlzheimerDataset(Dataset):
    """Image-classification dataset laid out as one sub-directory per class.

    Every regular file inside ``root_dir/<category>`` is treated as one image.
    Its label is the category's index in ``categories``. Images are loaded as
    grayscale (PIL mode ``'L'``).
    """

    DEFAULT_CATEGORIES = ('Mild Dementia', 'Moderate Dementia', 'Non Demented', 'Very mild Dementia')

    def __init__(self, root_dir, transform=None, categories=None):
        """
        Args:
            root_dir (string): Directory with all the image categories.
            transform (callable, optional): Optional transform applied to each PIL image.
            categories (sequence of str, optional): Class sub-directory names, in
                label order. Defaults to ``DEFAULT_CATEGORIES``.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []

        if categories is None:
            categories = self.DEFAULT_CATEGORIES

        for label, category in enumerate(categories):
            category_path = os.path.join(root_dir, category)
            # Sort so sample order (and therefore any index-based split) does not
            # depend on the filesystem's listing order.
            for img_name in sorted(os.listdir(category_path)):
                img_path = os.path.join(category_path, img_name)
                # Skip sub-directories; they cannot be opened as images.
                if os.path.isfile(img_path):
                    self.image_paths.append(img_path)
                    self.labels.append(label)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)``, where label is a ``torch.long`` scalar tensor."""
        img_path = self.image_paths[idx]
        # Use a context manager so the file handle is closed right away;
        # convert() returns a new, independent grayscale image.
        with Image.open(img_path) as img:
            image = img.convert('L')

        if self.transform:
            image = self.transform(image)

        label_tensor = torch.tensor(self.labels[idx], dtype=torch.long)
        return image, label_tensor

In [3]:
# Define transformations. Random augmentation is applied to the training split
# only. Validation and test images use a deterministic pipeline, so their
# metrics reflect the unmodified images and are comparable between epochs.
import copy

from torch.utils.data import WeightedRandomSampler, random_split
from torchvision import transforms

train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])
eval_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])
# `transform` remains the name of the training pipeline for existing references.
transform = train_transform

# Create the dataset once, plus a shallow copy that shares the same file list
# and labels but applies the evaluation transform.
dataset = AlzheimerDataset(root_dir='./Data', transform=train_transform)
eval_dataset = copy.copy(dataset)
eval_dataset.transform = eval_transform

# Define the sizes for train, validation, and test sets
total_size = len(dataset)
train_size = int(0.7 * total_size)
val_size = int(0.15 * total_size)
test_size = total_size - train_size - val_size

# Split with a seeded generator so the same images land in each split every run.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_split, test_split = random_split(
    dataset, [train_size, val_size, test_size], generator=split_generator
)
# Validation/test subsets index the un-augmented view of the same samples.
val_dataset = Subset(eval_dataset, val_split.indices)
test_dataset = Subset(eval_dataset, test_split.indices)

# Class-balancing weights computed from the training split only. Each training
# sample is weighted by 1 / (number of training samples of its class).
# bincount indexes directly by label value, so a class that is missing cannot
# shift the weights of the others.
train_indices = train_dataset.indices
train_labels = np.asarray(dataset.labels)[train_indices]
class_sample_count = np.bincount(train_labels)
train_weights = 1.0 / class_sample_count[train_labels]

# Create the sampler for the training set
train_sampler = WeightedRandomSampler(train_weights, len(train_weights))

# Number of workers for data loading
num_workers = 0  # Adjust based on your system's capability
# Pinned memory only helps when batches are copied to a CUDA device.
pin_memory = torch.cuda.is_available()

# Create the DataLoaders with the sampler for the training set
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    sampler=train_sampler,
    shuffle=False,  # must stay False when a sampler is supplied
    num_workers=num_workers,
    pin_memory=pin_memory
)

val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=pin_memory
)

test_loader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=pin_memory
)


Show the data shapes of one batch

In [4]:
# Sanity check: pull a single mini-batch from the training loader and print
# its shapes, then the shape and label of every sample in it.
for images, labels in train_loader:
    print(f'Batch images shape: {images.shape}')  # expected [batch_size, 1, height, width]
    print(f'Batch labels shape: {labels.shape}')  # expected [batch_size]

    # Walk the batch sample by sample (numbering from 1 for display).
    for sample_no, (sample_img, sample_label) in enumerate(zip(images, labels), start=1):
        print(f'Image {sample_no} shape: {sample_img.shape}, Label: {sample_label}')

    # One batch is enough; stop before consuming the rest of the loader.
    break



Batch images shape: torch.Size([32, 1, 224, 224])
Batch labels shape: torch.Size([32])
Image 1 shape: torch.Size([1, 224, 224]), Label: 1
Image 2 shape: torch.Size([1, 224, 224]), Label: 0
Image 3 shape: torch.Size([1, 224, 224]), Label: 0
Image 4 shape: torch.Size([1, 224, 224]), Label: 2
Image 5 shape: torch.Size([1, 224, 224]), Label: 1
Image 6 shape: torch.Size([1, 224, 224]), Label: 2
Image 7 shape: torch.Size([1, 224, 224]), Label: 1
Image 8 shape: torch.Size([1, 224, 224]), Label: 2
Image 9 shape: torch.Size([1, 224, 224]), Label: 1
Image 10 shape: torch.Size([1, 224, 224]), Label: 1
Image 11 shape: torch.Size([1, 224, 224]), Label: 1
Image 12 shape: torch.Size([1, 224, 224]), Label: 2
Image 13 shape: torch.Size([1, 224, 224]), Label: 0
Image 14 shape: torch.Size([1, 224, 224]), Label: 0
Image 15 shape: torch.Size([1, 224, 224]), Label: 0
Image 16 shape: torch.Size([1, 224, 224]), Label: 2
Image 17 shape: torch.Size([1, 224, 224]), Label: 2
Image 18 shape: torch.Size([1, 224, 22

Train

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Hyperparameters
# NOTE(review): 0.005 is unusually high for Adam on a ViT trained from scratch
# (values around 1e-4 to 3e-4 are common); lower it if the loss diverges.
learning_rate = 0.005
num_epochs = 10

# Use the GPU when one is available. A ViT of this size (~85M parameters) is
# very slow to train on CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model
vit_model = ViT(
    img_size=224,
    patch_size=16,
    embed_dim=768,
    hidden_dim=3072,
    num_heads=12,
    num_layers=12
).to(device)
model = vit_model
print(model)
# Total parameters and trainable parameters.
total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,} total parameters.")
total_trainable_params = sum(
    p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total_trainable_params:,} training parameters.")

# Sanity-check the forward pass on one random single-channel 224x224 input.
# No gradients are needed for this check.
dummy_input = torch.randn(1, 1, 224, 224, device=device)
with torch.no_grad():
    output = model(dummy_input)
print(f"Output shape from model: {output.shape}")

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(vit_model.parameters(), lr=learning_rate)

# Define directory for model checkpoints
checkpoint_dir = 'ModelCheckpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

# Training loop
for epoch in range(num_epochs):
    # Training phase
    vit_model.train()
    running_loss = 0.0
    num_train_samples = 0
    for inputs, labels in train_loader:
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad()
        outputs = vit_model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
        num_train_samples += inputs.size(0)
    # Average over the samples actually drawn by the sampler.
    epoch_train_loss = running_loss / max(num_train_samples, 1)

    # Validation phase
    vit_model.eval()
    running_val_loss = 0.0
    num_val_samples = 0
    num_val_correct = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            outputs = vit_model(inputs)
            loss = criterion(outputs, labels)
            running_val_loss += loss.item() * inputs.size(0)
            num_val_correct += (outputs.argmax(dim=1) == labels).sum().item()
            num_val_samples += inputs.size(0)
    epoch_val_loss = running_val_loss / max(num_val_samples, 1)
    epoch_val_acc = num_val_correct / max(num_val_samples, 1)

    # Print epoch training and validation metrics
    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {epoch_train_loss:.4f}, '
          f'Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.4f}')

    # Save model checkpoint
    checkpoint_path = os.path.join(checkpoint_dir, f'vit_model_epoch_{epoch+1}.pt')
    torch.save(vit_model.state_dict(), checkpoint_path)
    print(f'Model checkpoint saved at {checkpoint_path}')

print('Finished Training')

ViT(
  (patches): CreatePatches(
    (patch): Conv2d(1, 768, kernel_size=(16, 16), stride=(16, 16))
  )
  (attn_layers): ModuleList(
    (0-11): 12 x AttentionBlock(
      (pre_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
      )
      (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (MLP): Sequential(
        (0): Linear(in_features=768, out_features=3072, bias=True)
        (1): GELU(approximate='none')
        (2): Dropout(p=0.0, inplace=False)
        (3): Linear(in_features=3072, out_features=768, bias=True)
        (4): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (dropout): Dropout(p=0.0, inplace=False)
  (ln): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  (head): Linear(in_features=768, out_features=4, bias=True)
)
85,408,516 total parameters.
85,408,516 training parameters.
Output shape from m

KeyboardInterrupt: 