In [1]:
# Configuration Parameters
BATCH_SIZE = 64                     # Number of images processed in one training batch
IMG_SIZE = (128, 128)               # Size of input images (height, width)
PATCH_SIZE = (8, 8)                 # Size of each patch extracted from images (height, width)

# Fail early if the image cannot be tiled exactly by the patches.
assert IMG_SIZE[0] % PATCH_SIZE[0] == 0 and IMG_SIZE[1] % PATCH_SIZE[1] == 0, \
    "IMG_SIZE must be divisible by PATCH_SIZE"

# Number of patches per image: patch rows times patch columns.
# (Computed from both dimensions so non-square images/patches are counted correctly.)
NUM_PATCH = (IMG_SIZE[0] // PATCH_SIZE[0]) * (IMG_SIZE[1] // PATCH_SIZE[1])

# Number of tokens left once the patch sequence has been shortened by a factor of 4
T = NUM_PATCH // 4

# Transformer Configuration
NUM_H = 4                           # Number of attention heads in the Transformer
NUM_TR = 11                         # Number of Transformer blocks/layers

In [2]:
# For this notebook to run with updated APIs, we need torch 1.12+ and torchvision 0.13+
try:
    import torch
    import torchvision
    assert int(torch.__version__.split(".")[1]) >= 12 or int(torch.__version__.split(".")[0]) == 2, "torch version should be 1.12+"
    assert int(torchvision.__version__.split(".")[1]) >= 13, "torchvision version should be 0.13+"
    print(f"torch version: {torch.__version__}")
    print(f"torchvision version: {torchvision.__version__}")
except:
    print(f"[INFO] torch/torchvision versions not as required, installing nightly versions.")
    !pip3 install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
    import torch
    import torchvision
    print(f"torch version: {torch.__version__}")
    print(f"torchvision version: {torchvision.__version__}")

torch version: 2.1.1+cu121
torchvision version: 0.16.1+cu121


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Continue with regular imports
import matplotlib.pyplot as plt
import torch
import torchvision

from torch import nn
from torchvision import transforms

from torchinfo import summary
import data_setup, engine
from helper_functions import download_data, set_seeds, plot_loss_curves

In [4]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [5]:
import requests
import zipfile
from pathlib import Path

# Parent folder of the dataset, and the dataset folder inside it
data_path = Path("../")
image_path = data_path.joinpath("OCT2017")


In [6]:
# Training and test image folders inside the dataset directory
train_dir, test_dir = image_path / "train", image_path / "test"
train_dir, test_dir

(PosixPath('../OCT2017/train'), PosixPath('../OCT2017/test'))

**************************************************************************************

In [7]:
# Preprocessing applied to every image: resize to IMG_SIZE, then convert to a tensor.
my_transform = transforms.Compose(
    [
        transforms.Resize(IMG_SIZE),
        transforms.ToTensor(),
    ]
)

In [8]:
# Helper that builds the column-by-column patch ordering
def generate_vertical_permutation(H_p, W_p):
    """
    Build the index list that reorders a row-major patch grid column by column.

    Inputs:
        H_p - Number of patch rows in the grid
        W_p - Number of patch columns in the grid
    Returns:
        perm - List of H_p * W_p row-major indices, visiting each column from
               top to bottom before moving to the next column
    """
    # Row-major index of the patch at (row, col) is row * W_p + col.
    return [row * W_p + col for col in range(W_p) for row in range(H_p)]


def img_to_patch_vertical(x, patch_size, flatten_channels=True):
    """
    Extracts patches from images in vertical order, column by column.

    The patch at grid position (row i, column j) ends up at sequence index
    j * H_p + i, where H_p is the number of patch rows.

    Inputs:
        x - torch.Tensor representing the image of shape [B, C, H, W]
        patch_size - Number of pixels per dimension of the patches (integer)
        flatten_channels - If True, the patches will be returned in a flattened format
                           as a feature vector instead of an image grid.

    Returns:
        x - Tensor of extracted patches in vertical order, of shape
            [B, H_p * W_p, C * patch_size * patch_size] if flatten_channels is True,
            otherwise [B, H_p * W_p, C, patch_size, patch_size].

    Raises:
        AssertionError - If H or W is not divisible by patch_size.
    """
    B, C, H, W = x.shape

    # Ensure H and W are divisible by patch_size
    assert H % patch_size == 0 and W % patch_size == 0, "H and W must be divisible by patch_size"

    H_p = H // patch_size
    W_p = W // patch_size

    # Split both spatial axes into (grid index, pixel index within the patch).
    x = x.reshape(B, C, H_p, patch_size, W_p, patch_size)
    # Put the column axis before the row axis so that flattening the two grid axes
    # directly yields the column-by-column order; this avoids building a Python
    # index list and a device-side index tensor on every call.
    x = x.permute(0, 4, 2, 1, 3, 5)  # [B, W_p, H_p, C, p_H, p_W]
    x = x.reshape(B, W_p * H_p, C, patch_size, patch_size)  # [B, Num_Patches, C, p_H, p_W]

    if flatten_channels:
        # Flatten the C, p_H, p_W dimensions into a single feature vector per patch
        x = x.flatten(2, 4)  # [B, Num_Patches, C * p_H * p_W]

    return x

In [9]:
class AttentionBlock(nn.Module):
    """Transformer encoder block: layer-normed self-attention followed by a
    layer-normed feed-forward network, each wrapped in a residual connection."""

    def __init__(self, embed_dim, hidden_dim, num_heads, dropout=0.0):
        """
        Inputs:
            embed_dim - Size of the token vectors entering and leaving the block
            hidden_dim - Width of the hidden layer of the feed-forward network
                         (usually 2-4x larger than embed_dim)
            num_heads - Number of heads of the Multi-Head Attention layer
            dropout - Dropout rate passed to the attention layer and used after
                      each stage of the feed-forward network
        """
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        # Built with the default batch_first setting, i.e. sequence-first inputs.
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.layer_norm_2 = nn.LayerNorm(embed_dim)

        feed_forward = [
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        ]
        self.linear = nn.Sequential(*feed_forward)

    def forward(self, x):
        """Apply self-attention and the feed-forward network; output has the shape of x."""
        # Self-attention on the normalised input, added back onto the raw input.
        normed = self.layer_norm_1(x)
        attn_out, _ = self.attn(normed, normed, normed)
        x = x + attn_out
        # Feed-forward network on the normalised result, again as a residual.
        return x + self.linear(self.layer_norm_2(x))

In [10]:
class Transpose(nn.Module):
    """Module wrapper around a two-axis swap so it can be placed inside nn.Sequential."""

    def __init__(self, dim1, dim2):
        """
        Inputs:
            dim1 - First of the two axes to swap
            dim2 - Second of the two axes to swap
        """
        super().__init__()
        self.dim1, self.dim2 = dim1, dim2

    def forward(self, x):
        """Return x with axes dim1 and dim2 exchanged."""
        return torch.transpose(x, self.dim1, self.dim2)


In [11]:
class VisionTransformer(nn.Module):
    """Vision Transformer whose patch embedding is a small 3D-convolutional stem.

    The image is cut into patches in column-by-column order, the patch sequence is
    treated as the depth axis of a 3D volume, convolved and pooled along that axis,
    and the resulting tokens are fed to a stack of AttentionBlocks. The CLS token
    output is used for classification.
    """

    def __init__(self, embed_dim, hidden_dim, n_feature, num_channels, num_heads, num_layers, num_classes, patch_size, num_patches, dropout=0.0):
        """
        Inputs:
            embed_dim - Dimensionality of the input feature vectors to the Transformer
            hidden_dim - Dimensionality of the hidden layer in the feed-forward networks
                         within the Transformer
            n_feature - Base width of the convolutional stem; the three Conv3d layers
                        output n_feature*3, n_feature*6 and n_feature*8 channels
            num_channels - Number of channels of the input (3 for RGB)
            num_heads - Number of heads to use in the Multi-Head Attention block
            num_layers - Number of layers to use in the Transformer
            num_classes - Number of classes to predict
            patch_size - Number of pixels that the patches have per dimension
            num_patches - Maximum number of patches an image can have
            dropout - Amount of dropout to apply in the feed-forward network and
                      on the input encoding
        """
        super().__init__()

        self.patch_size = patch_size

        # The two MaxPool3d layers below each halve the patch axis, so the stem
        # emits at most num_patches // 4 tokens.
        max_tokens = num_patches // 4

        # All sizes are derived from the constructor arguments (not from notebook
        # globals) so the model follows the configuration it is given.
        self.input_layer = nn.Sequential(
            nn.Conv3d(num_channels, n_feature*3, kernel_size=3, padding=1, stride=1),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=(2,1,1), stride=(2,1,1), padding=(0,0,0)),

            nn.Conv3d(n_feature*3, n_feature*6, 3, padding=1, stride=1),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=(2,1,1), stride=(2,1,1), padding=(0,0,0)),
            nn.Conv3d(n_feature*6, n_feature*8, 3, padding=1, stride=1),

            # [B, channels, tokens, p, p] -> [B, tokens, channels, p, p]
            Transpose(2,1),
            # One feature vector of channels * p * p values per token
            nn.Flatten(2,-1),
            nn.Linear(n_feature*8*patch_size*patch_size, embed_dim)
        )
        self.transformer = nn.Sequential(*[AttentionBlock(embed_dim, hidden_dim, num_heads, dropout=dropout) for _ in range(num_layers)])
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)
        )
        self.dropout = nn.Dropout(dropout)

        # Parameters/Embeddings
        self.cls_token = nn.Parameter(torch.randn(1,1,embed_dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, 1 + max_tokens, embed_dim))


    def forward(self, x):
        """
        Inputs:
            x - Image batch of shape [B, C, H, W]
        Returns:
            out - Class scores of shape [B, num_classes]
        """
        # Preprocess input: patches in vertical order, then move channels in front
        # of the patch axis -> [B, C, num_patches, p, p]
        x = img_to_patch_vertical(x, self.patch_size, flatten_channels=False).transpose(1,2)
        B = x.shape[0]
        x = self.input_layer(x)  # [B, tokens, embed_dim]

        # Add CLS token and positional encoding
        cls_token = self.cls_token.repeat(B, 1, 1)
        x = torch.cat([cls_token, x], dim=1)
        # Slice by the actual sequence length so inputs with fewer patches than
        # num_patches still line up with the positional embedding.
        x = x + self.pos_embedding[:, :x.shape[1]]

        # Apply Transformer (the attention blocks expect sequence-first input)
        x = self.dropout(x)
        x = x.transpose(0, 1)
        x = self.transformer(x)

        # Perform classification prediction from the CLS token (sequence position 0)
        cls = x[0]
        out = self.mlp_head(cls)
        return out

In [12]:
# Build the model from the notebook configuration and move it to the selected device.
AmirModel = VisionTransformer(
    embed_dim=256,
    hidden_dim=512,
    n_feature=8,
    num_channels=3,
    num_heads=NUM_H,
    num_layers=NUM_TR,
    num_classes=4,
    patch_size=PATCH_SIZE[0],
    num_patches=NUM_PATCH,
    dropout=0.1,
).to(device)

**************************************************************************************

In [37]:
# Print a per-layer summary of the model using torchinfo
summary(
    model=AmirModel,
    input_size=(32, 3, 128, 128),  # (batch_size, color_channels, height, width)
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)

Layer (type (var_name))                       Input Shape          Output Shape         Param #              Trainable
VisionTransformer (VisionTransformer)         [32, 3, 128, 128]    [32, 4]              16,896               True
├─Sequential (input_layer)                    [32, 3, 256, 8, 8]   [32, 64, 256]        --                   True
│    └─Conv3d (0)                             [32, 3, 256, 8, 8]   [32, 24, 256, 8, 8]  1,968                True
│    └─ReLU (1)                               [32, 24, 256, 8, 8]  [32, 24, 256, 8, 8]  --                   --
│    └─MaxPool3d (2)                          [32, 24, 256, 8, 8]  [32, 24, 128, 8, 8]  --                   --
│    └─Conv3d (3)                             [32, 24, 128, 8, 8]  [32, 48, 128, 8, 8]  31,152               True
│    └─ReLU (4)                               [32, 48, 128, 8, 8]  [32, 48, 128, 8, 8]  --                   --
│    └─MaxPool3d (5)                          [32, 48, 128, 8, 8]  [32, 48, 64, 8, 8]   -

In [14]:

# Build the train/test dataloaders and collect the class names.
# Could increase if we had more samples, such as here: https://arxiv.org/abs/2205.01580 (there are other improvements there too...)
train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(
    train_dir=train_dir,
    test_dir=test_dir,
    transform=my_transform,
    batch_size=BATCH_SIZE,
    num_workers=8,
)


In [15]:
# Show the class labels returned by create_dataloaders
class_names

['CNV', 'DME', 'DRUSEN', 'NORMAL']

In [16]:
# Sanity-check the data locations, the number of batches and the device.
print(f"train_dir: {train_dir}")
print(f"test_dir: {test_dir}")  # label fixed: this line reports the test directory
print(f"len train_dataloader: {len(train_dataloader)}")
print(f"len test_dataloader: {len(test_dataloader)}")
print(f"device: {device}")

train_dir: ../OCT2017/train
train_dir: ../OCT2017/test
len train_dataloader: 1305
len test_dataloader: 16
device: cuda


In [17]:
# Plain SGD over all model parameters (learning rate 0.01) and cross-entropy loss
optimizer = torch.optim.SGD(AmirModel.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

In [27]:
# Compile the model
# Select the 'high' float32 matmul precision setting before compiling.
torch.set_float32_matmul_precision('high')
# NOTE(review): rebinding AmirModel makes this cell non-idempotent — running it
# again wraps the already-compiled model a second time.
AmirModel = torch.compile(AmirModel)

In [19]:

# Train AmirModel with the project's engine helper and keep what it returns in `x`
# (saved to disk by a later cell).
# NOTE(review): the meaning of Pivot / Part is defined in engine.train — confirm there.
part = "2"
set_seeds()
x = engine.train(
    model=AmirModel,
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    epochs=20,
    Pivot=90,
    Part=part,
    save_weights=False,
    device=device,
)

  0%|                                                                                                                                                                                                                | 0/20 [00:16<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# Hand-written label describing the experiment; used as the file name when saving results.
# (A stray leading "+" made this assignment a SyntaxError.)
# NOTE(review): the label is not derived from the run configuration in this notebook —
# confirm it matches the run being saved.
name = "OPT = AdamW ,Compile = True , Epo = 10 ,lr = 0.001 ,least_time =253 ,IMG = 128 ,patch = 8 ,BATCH = 64 ,using_pretrain = True ,Part = 8 , tr_acc = 99.45 , te_acc = 98.93"
name

## Save the results

In [None]:
import numpy as np
# Write the training results `x` to disk under the experiment label `name`.
# NOTE(review): this saves under ./drive/MyDrive/TrackExp/ while the load cell
# reads from ./TrackExp/ — confirm which location is intended.
np.save("./drive/MyDrive/TrackExp/" + name , x)

## Load Results

In [None]:
# Label (file name without extension) of the stored results to load
load_name = "OPT = AdamW ,Compile = True , Epo = 20 ,lr = 0.001 ,least_time = 261 ,IMG = 128 ,patch = 8 ,BATCH = 64 ,using_pretrain = False ,Part = 0"

In [None]:
# Load the stored results selected by `load_name` (previously `name` was used here,
# which left `load_name` unused).
# SECURITY: allow_pickle=True lets NumPy unpickle object arrays, which can execute
# arbitrary code — only load files from a trusted source.
loaded_results = np.load("./TrackExp/" + load_name + ".npy", allow_pickle=True)
loaded_results

In [None]:
import copy
# Keep an independent deep copy of the current model object.
model_copy = copy.deepcopy(AmirModel)

## Save model

In [None]:
# Save the model
from going_modular.going_modular import utils

# `name2` was never defined in this notebook (NameError on a fresh kernel), so the
# experiment label `name` is used for the file name instead.
# NOTE(review): confirm `name` is the intended file name for the weights.
utils.save_model(model=AmirModel,
                 target_dir="./TrackExp",
                 model_name=name + ".pth")

## Load model weights

In [22]:
# Path of the checkpoint to restore.
# NOTE(review): this rebinds `name`, which earlier cells use as the experiment label.
name = "TrackExp/BestWeight_Part_19_Epoch_2_TrainLoss_0.01639417985006382_TestLoss_0.027976555515579093TrainAcc_0.9941451149425288_Testacc_0.9951171875.pth"
# map_location remaps the stored tensors onto the current device, so a checkpoint
# saved on a GPU can also be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
AmirModel.load_state_dict(torch.load(name, map_location=device))

<All keys matched successfully>

# Save Optimizer

In [25]:
# Write the optimizer's state dict to disk.
# NOTE(review): presumably ./TrackOptim already exists — confirm, this line does not create it.
torch.save(optimizer.state_dict(),"./TrackOptim/SGDOptimizer")

# Load Optimizer


In [22]:
# Restore the optimizer state from the file written by the save cell.
# NOTE(review): torch.load unpickles the file — only load files from a trusted source.
optimizer.load_state_dict(torch.load("./TrackOptim/SGDOptimizer"))