In [2]:
pip install torch torchvision einops


Collecting einops
  Downloading einops-0.8.0-py3-none-any.whl (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.2/43.2 kB[0m [31m777.9 kB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch)
  Using cached nvidia_cufft_cu12-11.0.2.54-

In [3]:
# Attach Google Drive to the Colab runtime so that files stored on it can be
# read through the local filesystem under the mount point below.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import matplotlib.pyplot as plt

# Root folder of the dataset on the mounted Drive; it is expected to contain
# one sub-folder per split (train / val / test), each in ImageFolder layout.
DATA_ROOT = '/content/drive/MyDrive/Colorectal Images/ColorectalImage2'

# Per-channel normalization constants (these are the ImageNet statistics
# documented by torchvision), shared by every split.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Training images are randomly cropped/flipped for augmentation.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# Validation and test images use a deterministic resize + center crop so that
# evaluation numbers are repeatable.
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])
batch_size = 32

# Load datasets.
# Each split is given the transform defined above for it: the augmenting
# pipeline for training, the deterministic one for validation and test.
train_dataset = datasets.ImageFolder(root=f'{DATA_ROOT}/train', transform=train_transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)

val_dataset = datasets.ImageFolder(root=f'{DATA_ROOT}/val', transform=val_transform)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

test_dataset = datasets.ImageFolder(root=f'{DATA_ROOT}/test', transform=val_transform)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)


In [8]:
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

class MultiScaleVisionTransformer(nn.Module):
    """Two-branch Vision Transformer that classifies an image at two patch scales.

    The image is tokenized twice, once with small and once with large square
    patches. Each token sequence gets its own class token, learned positional
    embedding and Transformer encoder. The longer sequence is then
    average-pooled to the length of the shorter one, both are linearly
    projected, summed, and passed through one shared encoder layer. The first
    token of each result feeds a per-branch linear head and the two sets of
    logits are added.

    Every encoder layer is built with ``batch_first=True`` because all tensors
    in ``forward`` are laid out as (batch, tokens, dim). With PyTorch's default
    sequence-first layout, attention would be computed across the images of a
    batch instead of across the tokens of one image.

    Args:
        img_size: Side length in pixels of the square input image; only used
            to size the positional embeddings.
        num_classes: Number of output logits.
        patch_size_s: Side length of the small patches.
        patch_size_l: Side length of the large patches.
        dim: Token embedding width used by both branches.
        depth: Number of encoder layers in each branch.
        heads: Number of attention heads per encoder layer.
        mlp_dim: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability for the embeddings and encoder layers.
    """

    def __init__(self, img_size=224, num_classes=9, patch_size_s=16, patch_size_l=32, dim=128, depth=6, heads=4, mlp_dim=256, dropout=0.1):
        super().__init__()

        # Define the two patch sizes and number of patches
        self.patch_size_s = patch_size_s
        self.patch_size_l = patch_size_l
        self.num_patches_s = (img_size // patch_size_s) ** 2
        self.num_patches_l = (img_size // patch_size_l) ** 2

        # Patch embedding for both branches: cut the image into non-overlapping
        # patches, flatten each one, and project it linearly to `dim`.
        # The `* 3` fixes the expected number of input channels to three.
        self.to_patch_embedding_s = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size_s, p2=patch_size_s),
            nn.Linear(patch_size_s * patch_size_s * 3, dim),
        )

        self.to_patch_embedding_l = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size_l, p2=patch_size_l),
            nn.Linear(patch_size_l * patch_size_l * 3, dim),
        )

        # Class tokens
        self.cls_token_s = nn.Parameter(torch.randn(1, 1, dim))
        self.cls_token_l = nn.Parameter(torch.randn(1, 1, dim))

        # Positional embeddings (one extra position for the class token)
        self.pos_embedding_s = nn.Parameter(torch.randn(1, self.num_patches_s + 1, dim))
        self.pos_embedding_l = nn.Parameter(torch.randn(1, self.num_patches_l + 1, dim))

        # Dropout applied to the embedded tokens of both branches
        self.dropout = nn.Dropout(dropout)

        # Transformer encoders for each branch.
        # batch_first=True: inputs are (batch, tokens, dim), see class docstring.
        self.transformer_s = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(dim, heads, mlp_dim, dropout, batch_first=True), depth)
        self.transformer_l = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(dim, heads, mlp_dim, dropout, batch_first=True), depth)

        # Linear layers applied to each branch before fusion
        self.proj_s = nn.Linear(dim, dim)
        self.proj_l = nn.Linear(dim, dim)

        # Fusion layer (a regular self-attention encoder layer applied to the
        # sum of both branches)
        self.cross_attention = nn.TransformerEncoderLayer(dim, heads, mlp_dim, dropout, batch_first=True)

        # MLP heads
        self.mlp_head_s = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))
        self.mlp_head_l = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))

    def _encode_branch(self, img, to_patch_embedding, cls_token, pos_embedding, transformer):
        """Tokenize ``img`` at one patch scale and run that branch's encoder."""
        patches = to_patch_embedding(img)
        cls_tokens = repeat(cls_token, '() n d -> b n d', b=img.shape[0])
        x = torch.cat((cls_tokens, patches), dim=1)
        x = x + pos_embedding[:, :x.size(1)]
        x = self.dropout(x)
        return transformer(x)

    def forward(self, img):
        """Return class logits of shape (batch, num_classes) for a batch of images.

        Args:
            img: Tensor of shape (batch, 3, height, width). Height and width
                must be divisible by both patch sizes, and the resulting token
                counts must not exceed the positional-embedding lengths that
                were derived from ``img_size``.
        """
        # Small patches branch
        x_s = self._encode_branch(
            img, self.to_patch_embedding_s, self.cls_token_s, self.pos_embedding_s, self.transformer_s)

        # Large patches branch
        x_l = self._encode_branch(
            img, self.to_patch_embedding_l, self.cls_token_l, self.pos_embedding_l, self.transformer_l)

        # Bring both sequences to the same length by average-pooling the
        # longer one along the token axis.
        # NOTE(review): the class token is pooled together with the patch
        # tokens, so index 0 of the pooled sequence is a mixture of the class
        # token and its neighbouring patches - confirm this is intended.
        if x_s.size(1) > x_l.size(1):
            x_s = F.adaptive_avg_pool1d(x_s.transpose(1, 2), output_size=x_l.size(1)).transpose(1, 2)
        else:
            x_l = F.adaptive_avg_pool1d(x_l.transpose(1, 2), output_size=x_s.size(1)).transpose(1, 2)

        # Project both branches
        x_s_proj = self.proj_s(x_s)
        x_l_proj = self.proj_l(x_l)

        # Fusion
        # NOTE(review): both calls receive the same sum, so outside of
        # training-time dropout the two results are identical; a true
        # cross-attention would use one branch as query and the other as
        # key/value. Left unchanged to keep the architecture and parameters.
        x_s = self.cross_attention(x_s_proj + x_l_proj)
        x_l = self.cross_attention(x_l_proj + x_s_proj)

        # Classification from the first token of each fused sequence
        logits_s = self.mlp_head_s(x_s[:, 0])
        logits_l = self.mlp_head_l(x_l[:, 0])

        return logits_s + logits_l


In [6]:
# Define the number of classes
num_classes = 9

# Run on the GPU when one is available and fall back to the CPU otherwise,
# so the cell does not crash on a runtime without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Instantiate the model
model = MultiScaleVisionTransformer(num_classes=num_classes).to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Learning rate scheduler: multiply the learning rate by 0.1 every 10 epochs
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy_percent)``.

    When ``optimizer`` is given, the model is put in training mode and its
    weights are updated after every batch. When it is ``None``, the model is
    put in evaluation mode and no gradients are computed.
    The loss is averaged per sample (each batch loss is weighted by its size).
    """
    is_training = optimizer is not None
    model.train(is_training)

    loss_sum = 0.0
    correct = 0
    seen = 0

    with torch.set_grad_enabled(is_training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            if is_training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            if is_training:
                loss.backward()
                optimizer.step()

            batch_count = labels.size(0)
            loss_sum += loss.item() * batch_count
            correct += outputs.argmax(dim=1).eq(labels).sum().item()
            seen += batch_count

    return loss_sum / seen, 100. * correct / seen


# Training loop; the per-epoch history is kept for plotting
num_epochs = 20
train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []

for epoch in range(num_epochs):
    train_loss, train_accuracy = _run_epoch(model, train_loader, criterion, device, optimizer)
    val_loss, val_accuracy = _run_epoch(model, val_loader, criterion, device)

    scheduler.step()

    # All four lists are appended together, once per completed epoch, so they
    # always have the same length even if training is interrupted.
    train_losses.append(train_loss)
    train_accuracies.append(train_accuracy)
    val_losses.append(val_loss)
    val_accuracies.append(val_accuracy)

    print(f"Epoch {epoch + 1}/{num_epochs}, "
          f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, "
          f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%")


  self.pid = os.fork()


KeyboardInterrupt: 

In [None]:
# Plotting accuracy and loss
# The x-axis is derived from the number of epochs actually recorded rather
# than from num_epochs: if training stopped early, the history lists are
# shorter than num_epochs and plotting against range(1, num_epochs + 1)
# would fail with a length mismatch.
completed_epochs = len(train_losses)
epochs_range = range(1, completed_epochs + 1)
plt.figure(figsize=(14, 5))

# Left panel: accuracy curves
plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_accuracies, label='Training Accuracy')
plt.plot(epochs_range, val_accuracies, label='Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

# Right panel: loss curves
plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_losses, label='Training Loss')
plt.plot(epochs_range, val_losses, label='Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')

plt.show()
