In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from typing import Callable

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

class SelfAttentionLayer(nn.Module):
    def __init__(self, in_dim, out_dim, heads, similarity_fun=None):
        super().__init__()
        self.query = nn.Linear(in_dim, out_dim)
        self.key = nn.Linear(in_dim, out_dim)
        self.value = nn.Linear(in_dim, out_dim)
        self.heads = heads
        self.out_dim_per_head = out_dim // heads
        self.output = nn.Linear(out_dim, in_dim)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        q = self.query(x).view(batch_size, seq_len, self.heads, self.out_dim_per_head).transpose(1, 2)
        k = self.key(x).view(batch_size, seq_len, self.heads, self.out_dim_per_head).transpose(1, 2)
        v = self.value(x).view(batch_size, seq_len, self.heads, self.out_dim_per_head).transpose(1, 2)
        
        # Calculate the attention scores
        attention_scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.out_dim_per_head, dtype=torch.float32))
        attention = torch.softmax(attention_scores, dim=-1)
        
        # Apply attention to the values
        out = torch.matmul(attention, v)
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        out = self.output(out)
        
        return self.gamma * out + x


In [3]:
def similarity_fun(Q: torch.tensor, K: torch.tensor):
    # print("Q", Q.shape)
    # print("K.T", K.transpose(1, 2).shape)
    return torch.bmm(Q, K.transpose(1, 2))/torch.sqrt(torch.tensor(K.shape[-1]))

In [4]:
# Assuming input shape (batch_size, n_patches, C * patch_size * patch_size)
class AddPositionEmbedding(nn.Module):
    def __init__(self, num_patches, emb_dim) -> None:
        super().__init__()
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, emb_dim))  # +1 for class token

    def forward(self, x):
        return x + self.pos_embedding


In [5]:
class MlpLayer(nn.Module):
    def __init__(self, input_dim, mlp_dim, dropout=0.5) -> None:
        super().__init__()
        self.layer1 = nn.Linear(input_dim, mlp_dim)
        self.gelu = nn.GELU()
        self.dropout1 = nn.Dropout(dropout)
        self.layer2 = nn.Linear(mlp_dim, input_dim)
        self.dropout2 = nn.Dropout(dropout)
    
    def forward(self, x):
        x = self.gelu(self.layer1(x))
        x = self.dropout1(x)
        x = self.layer2(x)
        x = self.dropout2(x)
        return x

In [6]:
class Encoder(nn.Module):
    def __init__(self, input_dim, patch_size, att_dim, head, mlp_dim, num_patches, dropout=0.5) -> None:
        super().__init__()
        self.ln1 = nn.LayerNorm(input_dim)
        # Adjust the SelfAttentionLayer initialization if needed
        self.att = SelfAttentionLayer(input_dim, att_dim, head, similarity_fun)
        self.ln2 = nn.LayerNorm(att_dim)
        self.mlp = MlpLayer(input_dim, mlp_dim, dropout)

    def forward(self, x):
        # Apply the first LayerNorm
        x = self.ln1(x)
        # Save the input for the skip connection
        x_skip = x
        # Apply the self-attention
        x = self.att(x)
        # Add the skip connection (residual)
        x = x + x_skip

        # Apply the second LayerNorm
        x = self.ln2(x)
        # Save the input for the second skip connection
        x_skip = x
        # Apply the MLP layer
        x = self.mlp(x)
        # Add the second skip connection (residual)
        x = x + x_skip
        
        return x


In [7]:
class VisionTransformer(nn.Module):
    def __init__(self, img_size, patch_size, num_channels, num_classes, emb_dim, num_heads, mlp_dim, num_layers, dropout_rate=0.5):
        super().__init__()
        num_patches = (img_size // patch_size) ** 2
        self.patch_size = patch_size
        self.emb_dim = emb_dim
        self.flatten_dim = patch_size * patch_size * num_channels
        self.linear_proj = nn.Linear(self.flatten_dim, emb_dim)
        self.class_token = nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embedding = AddPositionEmbedding(num_patches, emb_dim)

        self.layers = nn.ModuleList([
            Encoder(input_dim=emb_dim, patch_size=patch_size, att_dim=emb_dim, head=num_heads, mlp_dim=mlp_dim, num_patches=num_patches, dropout=dropout_rate)
            for _ in range(num_layers)
        ])
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(emb_dim),
            nn.Linear(emb_dim, num_classes)
        )

    def forward(self, x):
        # Unfold patches
        x = x.unfold(2, self.patch_size, self.patch_size) \
            .unfold(3, self.patch_size, self.patch_size)
        # Reshape: [Batch, Channels, Patch_height, Patch_width, Num_patches_height, Num_patches_width]
        x = x.permute(0, 4, 5, 1, 2, 3)
        # Flatten patches
        x = x.contiguous().view(x.size(0), -1, self.flatten_dim) # -1 here automatically calculates the correct number of patches
        x = self.linear_proj(x)

        # Add class token and position embeddings
        class_tokens = self.class_token.expand(x.size(0), -1, -1)
        x = torch.cat((class_tokens, x), dim=1)
        x = self.pos_embedding(x)

        # Process through the layers
        for layer in self.layers:
            x = layer(x)

        # Classifier token
        x = x[:, 0]
        x = self.mlp_head(x)
        return x



In [45]:
# Assuming VisionTransformer class is defined as above

# Define transformations and load CIFAR-10 dataset
# transform = transforms.Compose([
#     transforms.Resize((224, 224)),  # Resize images to fit the model's expected input size
#     transforms.ToTensor(),
#     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
# ])


# transform = transforms.Compose([
#     transforms.RandomCrop(32, padding=4),  # Apply random crops
#     transforms.RandomHorizontalFlip(),  # Horizontal flipping
#     transforms.ToTensor(),
#     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
# ])

from torchvision import transforms

# Enhanced data transformations with Color Jitter and Random Erasing
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  # Random crop with padding
    transforms.RandomHorizontalFlip(),  # Horizontal flip
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),  # Color jitter
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False)  # Random erasing
])

# For validation, we usually keep it simple
transform_val = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])


train_dataset = datasets.CIFAR10(root='./', train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root='./', train=False, download=True, transform=transform_val)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)



Files already downloaded and verified
Files already downloaded and verified


In [9]:
# # Initialize the model
# model = VisionTransformer(img_size=224, patch_size=16, num_channels=3, num_classes=10, emb_dim=768, num_heads=12, mlp_dim=3072, num_layers=1, dropout_rate=0.5)
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# print(device)
# model = model.to(device)

# # Loss function and optimizer
# criterion = nn.CrossEntropyLoss()
# optimizer = optim.Adam(model.parameters(), lr=0.001)

cuda


In [46]:
import torch.optim.lr_scheduler as lr_scheduler
# Adjust model parameters for CIFAR-10
model = VisionTransformer(
    img_size=32,  # Adjusted image size
    patch_size=4,  # Smaller patches to match the smaller image size
    num_channels=3,
    num_classes=10,
    emb_dim=256,  # Reduced dimensionality
    num_heads=8,  # Fewer heads to match the reduced dimensionality
    mlp_dim=512,  # Smaller MLP size
    num_layers=12,  # Keep a single layer for simplicity
    dropout_rate=0.1  # Adjusted dropout rate
)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Adjust the optimizer and learning rate
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Introduce learning rate scheduling
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

# Loss function remains unchanged
criterion = nn.CrossEntropyLoss()

# Print the device being used
print(device)

cuda


In [47]:
import os

def save_checkpoint(state, filename="model_checkpoint.tar"):
    os.makedirs(os.path.dirname(filename), exist_ok=True)  # Ensure directory exists
    torch.save(state, filename)

def load_checkpoint(checkpoint_path, model, optimizer, scheduler=None):
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if scheduler is not None and 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    return checkpoint.get('epoch', -1), checkpoint.get('best_accuracy', 0.0)

def train_and_validate(model, train_loader, test_loader, optimizer, criterion, device, scheduler=None, num_epochs=5, checkpoint_path=None, filename="checkpoints/model_checkpoint.tar"):
    start_epoch = 0
    best_accuracy = 0.0

    if checkpoint_path is not None and os.path.isfile(checkpoint_path):
        start_epoch, best_accuracy = load_checkpoint(checkpoint_path, model, optimizer, scheduler)
        print(f"Loaded checkpoint '{checkpoint_path}' (epoch {start_epoch}), best accuracy: {best_accuracy}%")
        start_epoch += 1  # Continue from next epoch

    for epoch in range(start_epoch, start_epoch + num_epochs):
        model.train()
        running_loss = 0.0
        for images, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{start_epoch + num_epochs}', leave=False):
            images, labels = images.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f'Epoch [{epoch+1}/{num_epochs + start_epoch}], Loss: {running_loss/len(train_loader):.4f}')
        
        # Update the learning rate scheduler after each epoch
        if scheduler is not None:
            scheduler.step()

        # Validation loop
        model.eval()
        total = 0
        correct = 0
        with torch.no_grad():
            for images, labels in tqdm(test_loader, desc='Validating', leave=False):
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        current_accuracy = 100 * correct / total
        print(f'Validation Accuracy: {current_accuracy:.2f}%')

        # Save checkpoint if current accuracy is the best
        if current_accuracy > best_accuracy:
            print("Saving new best model")
            best_accuracy = current_accuracy
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_accuracy': best_accuracy
            }
            if scheduler is not None:
                checkpoint['scheduler_state_dict'] = scheduler.state_dict()
            save_checkpoint(checkpoint, filename=filename)


In [48]:
filepath = "checkpoints/transformer/vit_layer_12_lr_0_001_op_adam_sch_Cosin.tar"
train_and_validate(model, train_loader, test_loader, optimizer, criterion, device, scheduler, num_epochs=50, filename=filepath)

                                                             

Epoch [1/50], Loss: 2.3327


                                                             

Validation Accuracy: 10.00%
Saving new best model


                                                             

Epoch [2/50], Loss: 2.3104


                                                             

Validation Accuracy: 10.00%


                                                             

Epoch [3/50], Loss: 2.3069


                                                             

Validation Accuracy: 10.00%


                                                             

Epoch [4/50], Loss: 2.3050


                                                             

Validation Accuracy: 10.00%


                                                             

Epoch [5/50], Loss: 2.3042


                                                             

Validation Accuracy: 10.00%


                                                             

Epoch [6/50], Loss: 2.3041


                                                             

Validation Accuracy: 10.00%


                                                             

Epoch [7/50], Loss: 2.3038


                                                             

Validation Accuracy: 10.00%


                                                             

KeyboardInterrupt: 

In [37]:

train_and_validate(model, train_loader, test_loader, optimizer, criterion, device, scheduler, num_epochs=10, checkpoint_path=filepath, filename=filepath)

Loaded checkpoint 'checkpoints/transformer/vit_layer_1_lr_0_001_op_adam.tar' (epoch 37), best accuracy: 48.36%


                                                               

Epoch [39/48], Loss: 1.4326


                                                             

Validation Accuracy: 48.29%


                                                               

Epoch [40/48], Loss: 1.4305


                                                             

Validation Accuracy: 47.86%


                                                               

Epoch [41/48], Loss: 1.4279


                                                             

Validation Accuracy: 48.34%


                                                              

Epoch [42/48], Loss: 1.4299


                                                             

Validation Accuracy: 48.62%
Saving new best model


                                                               

Epoch [43/48], Loss: 1.4272


                                                             

Validation Accuracy: 48.58%


                                                               

Epoch [44/48], Loss: 1.4302


                                                             

Validation Accuracy: 48.45%


                                                               

Epoch [45/48], Loss: 1.4319


                                                             

Validation Accuracy: 48.95%
Saving new best model


                                                               

Epoch [46/48], Loss: 1.4236


                                                             

Validation Accuracy: 49.00%
Saving new best model


                                                               

Epoch [47/48], Loss: 1.4277


                                                             

Validation Accuracy: 48.26%


                                                              

Epoch [48/48], Loss: 1.4229


                                                             

Validation Accuracy: 48.57%


