In [1]:
import timm
import torch
import torch.nn as nn

print("Using Swin Transformer model")

model_name = "swin_base_patch4_window7_224"
pretrained = True
num_classes = 5
model = timm.create_model(model_name, pretrained=pretrained)
model.head.fc = nn.Linear(model.head.fc.in_features, num_classes)
    


Using Swin Transformer model


In [9]:
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of trainable parameters: {trainable_params}")

for k in range(4):
    print(f"Layer {k} parameters: {sum(p.numel() for p in model.layers[k].parameters())/1e6:.2f}M")

Number of trainable parameters: 86748349
Layer 0 parameters: 0.40M
Layer 1 parameters: 1.71M
Layer 2 parameters: 57.32M
Layer 3 parameters: 27.30M


In [7]:
model.layers[3]


SwinTransformerStage(
  (downsample): PatchMerging(
    (norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
    (reduction): Linear(in_features=2048, out_features=1024, bias=False)
  )
  (blocks): Sequential(
    (0): SwinTransformerBlock(
      (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (attn): WindowAttention(
        (qkv): Linear(in_features=1024, out_features=3072, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=1024, out_features=1024, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
        (softmax): Softmax(dim=-1)
      )
      (drop_path1): DropPath(drop_prob=0.096)
      (norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=1024, out_features=4096, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity()
        (fc2): Linear(in_features=4096, out_

In [3]:
from peft import get_peft_model, LoraConfig, TaskType
task_type=TaskType.FEATURE_EXTRACTION


r=16
lora_alpha=32
lora_dropout=0.2
target_modules = ["qkv"]  # Default for transformer models

lora_config = LoraConfig(
            r=r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            target_modules=target_modules,  # Default for transformer models
        )

2024-12-13 16:13:01.952549: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1734102781.963459 2347249 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1734102781.966722 2347249 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-13 16:13:01.979700: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [5]:
model = get_peft_model(model, lora_config)
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of trainable parameters: {trainable_params}")

x = torch.randn(1, 3, 224, 224)
out = model(x)
print(out.shape)

Number of trainable parameters: 770048
torch.Size([1, 5])


In [None]:
class SwinWithLoRA(nn.Module):
    def __init__(self, base_model, lora_config):
        super().__init__()
        self.base_model = get_peft_model(base_model, lora_config)

    def forward(self, inputs_ids):
        # Forward pass directly through the base model
        return self.base_model(inputs_ids)


In [37]:
model = timm.create_model(model_name, pretrained=pretrained)
model.head.fc = nn.Linear(model.head.fc.in_features, num_classes)

modelLora = SwinWithLoRA(model, lora_config)

In [35]:
x = torch.randn(1, 3, 224, 224)
out = modelLora(x)
print(out.shape)
trainable_params = sum(p.numel() for p in modelLora.parameters() if p.requires_grad)
print(f"Number of trainable parameters: {trainable_params}")

torch.Size([1, 5])
Number of trainable parameters: 770048
