In [None]:
import torch
import torch.nn as nn
from transformers import ViTModel, PreTrainedModel, AutoConfig
from torchvision import transforms
from PIL import Image

# Hugging Face Hub checkpoint id; passed to ViTModel.from_pretrained() further down
# to obtain the pre-trained backbone weights.
model_checkpoint = "facebook/deit-base-patch16-224"

class SkipViT(PreTrainedModel):
    """ViT classifier with attention-based token selection on the final hidden states.

    The wrapped ``ViTModel`` is always run in full. Afterwards, for every layer
    index listed in ``drop_layers``, that layer's attention map is used to
    discard the least-attended tokens from ``last_hidden_state``.

    NOTE(review): because selection happens after the encoder has finished, it
    does not reduce encoder compute, and since the [CLS] token is always kept
    at index 0 the logits equal those of the unpruned sequence. Saving compute
    would require dropping tokens between encoder layers.
    """

    def __init__(self, config, drop_layers=(6, 8), drop_ratio=0.35):
        """
        Build the ViT backbone and the linear classification head.

        Args:
            config: The configuration object for the ViT model; must provide
                ``hidden_size`` and ``num_labels``.
            drop_layers: Iterable of layer indices (0-based positions in
                ``outputs.attentions``) whose attention maps trigger a
                token-dropping step.
            drop_ratio: Proportion of the currently kept tokens removed at
                each dropping step.
        """
        super().__init__(config)
        self.vit = ViTModel(config)
        # Copy into a fresh list so instances never share a mutable default.
        self.drop_layers = list(drop_layers)
        self.drop_ratio = drop_ratio
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, pixel_values, labels=None):
        """
        Forward pass of the SkipViT model.

        Args:
            pixel_values: Input images processed into tensors.
            labels: Optional ground-truth labels for calculating loss.

        Returns:
            A dictionary with ``"loss"`` (``None`` when no labels are given)
            and ``"logits"`` of shape (batch_size, num_labels).
        """
        # Step 1: Run the ViT backbone and collect the per-layer attention maps
        outputs = self.vit(pixel_values, output_attentions=True)
        x = outputs.last_hidden_state  # Shape: (batch_size, num_tokens, hidden_size)
        attentions = outputs.attentions  # One attention map per encoder layer

        # Step 2: Apply token dropping. `token_indices` remembers which original
        # token each surviving row of `x` corresponds to, so later steps can
        # still look up scores in the full-size attention maps.
        token_indices = None
        for i, attn_scores in enumerate(attentions):
            if i in self.drop_layers:
                x, token_indices = self._select_tokens(x, attn_scores, token_indices)

        # Step 3: Use the [CLS] token (always kept at index 0) for classification
        cls_output = x[:, 0]  # Shape: (batch_size, hidden_size)
        logits = self.classifier(cls_output)  # Shape: (batch_size, num_labels)

        # Step 4: Calculate loss if labels are provided
        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)
        return {"loss": loss, "logits": logits}

    def drop_tokens(self, x, attn_scores, token_indices=None):
        """
        Drops the least important tokens based on attention scores.

        Args:
            x: Hidden states, shape (batch_size, num_tokens, hidden_size).
            attn_scores: Attention map of one layer, shape
                (batch_size, num_heads, total_tokens, total_tokens).
            token_indices: Optional (batch_size, num_tokens) tensor giving, for
                each row of ``x``, its position in the attention map. Required
                when ``x`` has already been pruned; defaults to the identity.

        Returns:
            The hidden states with dropped tokens removed; the [CLS] token
            stays at index 0 and the remaining tokens keep their order.

        Raises:
            ValueError: If ``x`` and ``attn_scores`` disagree on the number of
                tokens and no ``token_indices`` mapping is supplied.
        """
        kept_states, _ = self._select_tokens(x, attn_scores, token_indices)
        return kept_states

    def _select_tokens(self, x, attn_scores, token_indices=None):
        """Keep the most-attended tokens; return (kept states, their original indices)."""
        batch_size, num_tokens, hidden_size = x.shape
        total_tokens = attn_scores.shape[-1]

        if token_indices is None:
            if total_tokens != num_tokens:
                raise ValueError(
                    f"x has {num_tokens} tokens but attn_scores covers {total_tokens}; "
                    "pass token_indices to map the kept tokens onto the attention map"
                )
            token_indices = (
                torch.arange(num_tokens, device=x.device).unsqueeze(0).expand(batch_size, -1)
            )

        # Importance = attention *received* by each token, averaged over heads
        # (dim 1) and then over query positions. Averaging over the last axis
        # instead would give every token the same score, because each row of a
        # softmax attention map sums to one. The scores only drive index
        # selection, so they are detached from the autograd graph.
        received = attn_scores.detach().mean(dim=1).mean(dim=1)  # (batch_size, total_tokens)
        token_scores = received.gather(1, token_indices).clone()  # (batch_size, num_tokens)

        # The classifier reads position 0, so the [CLS] token must always survive.
        token_scores[:, 0] = float("inf")

        # Keep at least one token and never more than are currently present.
        num_keep = int((1 - self.drop_ratio) * num_tokens)
        num_keep = max(1, min(num_keep, num_tokens))

        # Sort the selected positions so surviving tokens keep their original
        # order (and [CLS] stays first) instead of being ordered by score.
        keep_positions = token_scores.topk(num_keep, dim=1).indices
        keep_positions = keep_positions.sort(dim=1).values  # (batch_size, num_keep)

        kept_indices = token_indices.gather(1, keep_positions)
        gather_index = keep_positions.unsqueeze(-1).expand(-1, -1, hidden_size)
        x = torch.gather(x, 1, gather_index)  # Shape: (batch_size, num_keep, hidden_size)
        return x, kept_indices

In [None]:
from transformers import Trainer

# Build the configuration for Food-101 (101 classes) from the checkpoint named in
# `model_checkpoint`, then construct SkipViT from that configuration. Only the
# configuration is fetched here; the pre-trained weights are loaded in a later step.
config = AutoConfig.from_pretrained(model_checkpoint, num_labels=101)
model = SkipViT(config)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/69.6k [00:00<?, ?B/s]

In [None]:
# Load the pre-trained weights into the custom model
pretrained_vit = ViTModel.from_pretrained(model_checkpoint)

# strict=False tolerates key mismatches, so keep the returned report and surface
# any mismatch instead of silently leaving parts of the backbone untouched.
load_result = model.vit.load_state_dict(pretrained_vit.state_dict(), strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"WARNING: missing keys: {load_result.missing_keys}; "
        f"unexpected keys: {load_result.unexpected_keys}"
    )
load_result

pytorch_model.bin:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTModel were not initialized from the model checkpoint at facebook/deit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


<All keys matched successfully>

In [None]:
def print_trainable_parameters(model):
    """Print how many parameters are trainable, how many exist, and the trainable share.

    Args:
        model: Object exposing ``named_parameters()`` whose parameters provide
            ``numel()`` and ``requires_grad`` (e.g. a ``torch.nn.Module``).

    The percentage is reported as 0.00 for a model without any parameters
    instead of raising ``ZeroDivisionError``.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count

    # Guard the division: a parameter-free module has all_param == 0.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct:.2f}"
    )

In [None]:
# Report how many of the model's parameters currently have requires_grad=True.
print_trainable_parameters(model)

trainable params: 86466917 || all params: 86466917 || trainable%: 100.00


In [None]:
# Re-initialize the classification head (an nn.Linear) with fresh random weights.
# Only `model.classifier` is touched; the `model.vit` backbone is not modified here.
model.classifier.reset_parameters()

In [None]:
# Walk every (sub)module of the model and print its qualified name next to its repr.
for name, module in model.named_modules():
    print(f"{name} : {module}")

 : SkipViT(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_feature

In [None]:
# Inference smoke test: push one random image-shaped tensor through the model.
dummy_image = torch.randn(1, 3, 224, 224)  # Simulating a random image input

# Run the forward pass without tracking gradients
model.eval()  # Set the model to evaluation mode
with torch.no_grad():
    output = model(pixel_values=dummy_image)

In [None]:
# Show the raw class scores; SkipViT.forward documents them as (batch_size, num_labels).
print("Logits:", output['logits'])
print("Logits shape:", output['logits'].shape)

Logits: tensor([[ 0.0947,  0.1992, -0.0480,  0.1394, -0.0107, -0.4648, -0.1788, -0.0120,
          0.0716,  0.0691,  0.1518, -0.0473, -0.0036, -0.4992,  0.1172, -0.1914,
          0.1567,  0.3927, -0.0686,  0.1564,  0.4542, -0.2070, -0.2059,  0.1021,
         -0.0523, -0.5833, -0.2759,  0.1688,  0.4489,  0.5220, -0.4654, -0.0900,
          0.0923, -0.2256, -0.3395, -0.1014,  0.3485, -0.6389, -0.0063,  0.3212,
          0.2804, -0.0295,  0.4600,  0.1507,  0.8533,  0.1982, -0.5711,  0.1475,
          0.5616, -0.2410, -0.1645,  0.1831,  0.3731,  0.0386,  0.3688, -0.0527,
         -0.3328,  0.3545, -0.1409,  0.3003,  0.4427,  0.1324,  0.0390, -0.1367,
         -0.3547, -0.1218, -0.0832,  0.5584,  0.0710,  0.2988, -0.0242,  0.2345,
         -0.0186, -0.1516,  0.3143, -0.3109,  0.3925,  0.0802,  0.3409, -0.2752,
          0.1691, -0.3331, -0.2018,  0.0398,  0.1389,  0.1329,  0.3719, -0.3374,
          0.1047,  0.1135, -0.1718,  0.2887,  0.1315,  0.0395, -0.0021,  0.1144,
         -0.0743, -0

In [None]:
# Training-mode smoke test: one forward/backward pass on a random image and label.
dummy_image = torch.randn(1, 3, 224, 224)  # Simulating a random image input
dummy_label = torch.tensor([5])  # Assuming class 5 is a valid label for the dataset

# Set model to training mode
model.train()

# Forward pass; passing a label makes the returned dict carry a loss as well
output = model(pixel_values=dummy_image, labels=dummy_label)
logits = output['logits']
loss = output['loss']

print(f"Logits: {logits}")
print(f"Loss: {loss}")

# Backward pass to compute gradients
loss.backward()

# Report, for every parameter, whether a gradient was populated
for name, param in model.named_parameters():
    if param.grad is not None:
        print(f"Gradient computed for {name}, shape: {param.grad.shape}")
    else:
        print(f"No gradient computed for {name}")

Logits: tensor([[-0.0140,  0.4286, -0.1043,  0.0278, -0.1569, -0.4503, -0.2652, -0.1300,
         -0.1995,  0.0204, -0.0244,  0.0679, -0.2699, -0.6569,  0.0715, -0.1789,
         -0.0305,  0.3496, -0.1724,  0.0099,  0.3932, -0.1806, -0.3071, -0.2582,
          0.1007, -0.3926, -0.3016,  0.2469,  0.4342,  0.3697, -0.3665,  0.0834,
         -0.0308, -0.3000, -0.4653, -0.0582,  0.0279, -0.3575,  0.0878,  0.2349,
          0.2837,  0.1378,  0.4317,  0.3510,  0.5176,  0.3630, -0.3817,  0.0481,
          0.4485, -0.1022, -0.2741,  0.4111, -0.0126, -0.1410,  0.1428,  0.1053,
         -0.0410,  0.2472, -0.0157,  0.0740,  0.3274,  0.0687,  0.2138, -0.1090,
         -0.3189,  0.0789, -0.3397,  0.4796, -0.0918,  0.5527,  0.1880,  0.1900,
          0.0082, -0.2682,  0.3098, -0.2226,  0.4184, -0.0014,  0.0713, -0.2546,
          0.1437, -0.3042, -0.4242, -0.0189,  0.1320,  0.3078,  0.1638, -0.1019,
          0.3246, -0.1904, -0.1408,  0.0590,  0.1441,  0.1964,  0.2076,  0.0293,
          0.0180,  0

In [None]:
!pip install datasets -q

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/480.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m27.1 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/116.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m11.0 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/179.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m179.3/179.3 kB[0m [31m17.8 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/134.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m14.1 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
from datasets import load_dataset
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torch

# Fetch the Food-101 dataset
dataset = load_dataset("food101")

# Both pipelines end with the same tensor conversion and per-channel normalisation,
# so that shared tail is built once and appended to each pipeline.
_tensor_and_normalize = [
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]

# Training: random crop to 224 and random horizontal flip
train_transforms = transforms.Compose(
    [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        *_tensor_and_normalize,
    ]
)

# Validation: resize to 256, then a centre crop to 224
val_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        *_tensor_and_normalize,
    ]
)

# Define a custom dataset class
class Food101Dataset(Dataset):
    """Map-style dataset that turns rows of an indexable image dataset into model inputs.

    Each row of ``dataset`` is expected to be a mapping with an ``"image"``
    entry exposing ``convert("RGB")`` and a ``"label"`` entry accepted by
    ``torch.tensor``.
    """

    def __init__(self, dataset, transforms=None):
        """
        Args:
            dataset: Indexable collection of rows supporting ``len()``.
            transforms: Optional callable applied to the RGB image.
        """
        self.dataset = dataset
        self.transforms = transforms

    def __len__(self):
        """Return the number of rows in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``{"pixel_values": image, "labels": tensor(label)}`` for row ``idx``."""
        # Fetch the row once: each `self.dataset[idx]` lookup may have to load
        # (and decode) the whole row again.
        example = self.dataset[idx]
        image = example["image"].convert("RGB")
        label = example["label"]

        # Explicit None check so a transform object that happens to be falsy
        # is still applied.
        if self.transforms is not None:
            image = self.transforms(image)

        return {"pixel_values": image, "labels": torch.tensor(label)}

# Wrap each split with its preprocessing pipeline
train_dataset = Food101Dataset(dataset["train"], transforms=train_transforms)
val_dataset = Food101Dataset(dataset["validation"], transforms=val_transforms)

# Batch loaders share every setting except shuffling (training only)
loader_kwargs = dict(batch_size=32, num_workers=4, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

README.md:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

train-00000-of-00008.parquet:   0%|          | 0.00/490M [00:00<?, ?B/s]

train-00001-of-00008.parquet:   0%|          | 0.00/464M [00:00<?, ?B/s]

train-00002-of-00008.parquet:   0%|          | 0.00/472M [00:00<?, ?B/s]

train-00003-of-00008.parquet:   0%|          | 0.00/464M [00:00<?, ?B/s]

train-00004-of-00008.parquet:   0%|          | 0.00/475M [00:00<?, ?B/s]

train-00005-of-00008.parquet:   0%|          | 0.00/470M [00:00<?, ?B/s]

train-00006-of-00008.parquet:   0%|          | 0.00/478M [00:00<?, ?B/s]

train-00007-of-00008.parquet:   0%|          | 0.00/486M [00:00<?, ?B/s]

validation-00000-of-00003.parquet:   0%|          | 0.00/423M [00:00<?, ?B/s]

validation-00001-of-00003.parquet:   0%|          | 0.00/413M [00:00<?, ?B/s]

validation-00002-of-00003.parquet:   0%|          | 0.00/426M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/75750 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/25250 [00:00<?, ? examples/s]

In [None]:
import torch.nn as nn
from transformers import AutoConfig
from torch.optim import AdamW
from transformers import get_scheduler

# Use the GPU when one is available; otherwise stay on the CPU instead of failing.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Define optimizer and learning rate scheduler
optimizer = AdamW(model.parameters(), lr=1e-5, weight_decay=0.01)
num_epochs = 5
# The linear schedule decays the learning rate to zero over exactly `total_steps`
# optimizer steps, so training has to run for `num_epochs` epochs to match it;
# any epochs beyond that would run with a learning rate of zero.
total_steps = len(train_loader) * num_epochs
scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=total_steps)

# Define loss function
criterion = nn.CrossEntropyLoss()

In [None]:
from tqdm import tqdm
import torch
from sklearn.metrics import accuracy_score

def train_one_epoch(model, train_loader, optimizer, scheduler, criterion, device):
    """
    Run one optimisation pass over ``train_loader`` with a progress bar.

    The loss is taken from the model's own output (``outputs["loss"]``);
    ``criterion`` is accepted but not used here. The scheduler is stepped once
    per batch, right after the optimizer.

    Returns:
        The sum of the per-batch losses divided by ``len(train_loader)``.
    """
    model.train()
    running_loss = 0
    batches = tqdm(train_loader, desc="Training", leave=False)

    for batch in batches:
        inputs = batch["pixel_values"].to(device)
        targets = batch["labels"].to(device)

        # Forward pass: the model returns its own loss when labels are given
        step_loss = model(pixel_values=inputs, labels=targets)["loss"]

        # Clear stale gradients, back-propagate, then update weights and learning rate
        optimizer.zero_grad()
        step_loss.backward()
        optimizer.step()
        scheduler.step()

        loss_value = step_loss.item()
        running_loss += loss_value
        batches.set_postfix(loss=loss_value)

    return running_loss / len(train_loader)

def evaluate(model, val_loader, criterion, device):
    """
    Evaluates the model with a progress bar.

    Args:
        model: Callable as ``model(pixel_values=...)`` returning a mapping with
            a ``"logits"`` entry.
        val_loader: Iterable of batches with ``"pixel_values"`` and ``"labels"``
            tensors; must support ``len()``.
        criterion: Loss function called as ``criterion(logits, labels)``.
        device: Device the batch tensors are moved to.

    Returns:
        ``(avg_loss, accuracy)`` where ``avg_loss`` is the sum of per-batch
        losses divided by ``len(val_loader)`` and ``accuracy`` comes from
        ``accuracy_score`` over all collected predictions.
    """
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []

    progress_bar = tqdm(val_loader, desc="Evaluating", leave=False)
    with torch.no_grad():
        for batch in progress_bar:
            pixel_values = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)

            # Labels are deliberately not passed to the model: the loss is
            # computed once, with `criterion`, rather than a second time
            # inside the model only to be discarded.
            outputs = model(pixel_values=pixel_values)
            logits = outputs["logits"]
            batch_loss = criterion(logits, labels).item()

            total_loss += batch_loss

            # Collect predictions and labels for accuracy computation
            preds = torch.argmax(logits, dim=-1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            progress_bar.set_postfix(loss=batch_loss)

    avg_loss = total_loss / len(val_loader)
    accuracy = accuracy_score(all_labels, all_preds)
    return avg_loss, accuracy

# Training loop with progress bar for each epoch
def train_model(model, train_loader, val_loader, optimizer, scheduler, criterion, device, num_epochs,
                checkpoint_path="best_skipvit_model.pth"):
    """
    Train for ``num_epochs`` epochs, evaluating after each one and saving the best weights.

    Args:
        model: Model to train; its ``state_dict()`` is what gets saved.
        train_loader: Batches for ``train_one_epoch``.
        val_loader: Batches for ``evaluate``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: Learning-rate scheduler, also stepped once per training
            batch, so it should have been built for
            ``num_epochs * len(train_loader)`` steps.
        criterion: Loss function forwarded to both helpers.
        device: Device the batches are moved to.
        num_epochs: Number of train/evaluate rounds to run.
        checkpoint_path: File the best-scoring ``state_dict`` is written to.
            Defaults to the previously hard-coded ``"best_skipvit_model.pth"``.
    """
    best_accuracy = 0

    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")

        # Training
        train_loss = train_one_epoch(model, train_loader, optimizer, scheduler, criterion, device)
        print(f"Training Loss: {train_loss:.4f}")

        # Evaluation
        val_loss, val_accuracy = evaluate(model, val_loader, criterion, device)
        print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")

        # Save the weights only when validation accuracy strictly improves
        if val_accuracy > best_accuracy:
            best_accuracy = val_accuracy
            torch.save(model.state_dict(), checkpoint_path)
            print("Model saved!")

    print("Training complete.")

In [None]:
from PIL import ImageFile
# Opt in to PIL loading image files whose data is truncated (see Pillow's ImageFile docs).
ImageFile.LOAD_TRUNCATED_IMAGES = True

import warnings
# NOTE(review): this installs a blanket "ignore" filter, hiding every Python warning
# from here on, including ones that might point at real problems during training.
warnings.filterwarnings("ignore")

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
criterion = torch.nn.CrossEntropyLoss()

# Train for exactly the number of epochs the LR scheduler was built for
# (`total_steps = len(train_loader) * num_epochs`). Running more epochs than that
# leaves the extra epochs at a learning rate of zero, so the weights and the
# validation metrics stop changing.
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=criterion,
    device=device,
    num_epochs=num_epochs
)

Epoch 1/10




Training Loss: 2.1719




Validation Loss: 0.8900, Validation Accuracy: 0.7884
Model saved!
Epoch 2/10




Training Loss: 1.1172




Validation Loss: 0.6306, Validation Accuracy: 0.8390
Model saved!
Epoch 3/10




Training Loss: 0.9224




Validation Loss: 0.5424, Validation Accuracy: 0.8571
Model saved!
Epoch 4/10




Training Loss: 0.8081




Validation Loss: 0.4989, Validation Accuracy: 0.8668
Model saved!
Epoch 5/10




Training Loss: 0.7490




Validation Loss: 0.4816, Validation Accuracy: 0.8715
Model saved!
Epoch 6/10




Training Loss: 0.7287




Validation Loss: 0.4816, Validation Accuracy: 0.8715
Epoch 7/10




Training Loss: 0.7215




Validation Loss: 0.4816, Validation Accuracy: 0.8715
Epoch 8/10




Training Loss: 0.7257




Validation Loss: 0.4816, Validation Accuracy: 0.8715
Epoch 9/10




Training Loss: 0.7245




Validation Loss: 0.4816, Validation Accuracy: 0.8715
Epoch 10/10




Training Loss: 0.7216


                                                                          

Validation Loss: 0.4816, Validation Accuracy: 0.8715
Training complete.


