# Swin-Transformer model

paper : https://arxiv.org/abs/2103.14030

In [1]:
import os
import torch
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch.nn as nn

import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

from transformers import SwinForImageClassification
from transformers import SwinForImageClassification, SwinConfig
from transformers import Swinv2Config, Swinv2Model
from sklearn.metrics import confusion_matrix, classification_report

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [12]:
data_dir = './../../dataset/train'  
run_folder = 'results/run4'        
batch_size = 16
num_epochs = 10
image_size = 224 
if not os.path.exists(run_folder):
    os.makedirs(run_folder)

## Transformation

In [13]:
# Random crop for training, Resize for validation
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(image_size),
    #transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

## dataset prepartion

In [14]:
#  dataset
dataset = datasets.ImageFolder(root=data_dir)

dataset.classes.sort(key=lambda x: int(x))  
label_to_class_name = {idx: class_name for idx, class_name in enumerate(dataset.classes)}

# Split the dataset 
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_dataset.dataset.transform = train_transform
val_dataset.dataset.transform = val_transform

# DataLoaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)


## Modified  Swin

In [15]:
# class CustomSwinTransformer(nn.Module):
#     def __init__(self, num_classes):
#         super(CustomSwinTransformer, self).__init__()
#         self.swin_transformer = create_model('swin_tiny_patch4_window7_224', pretrained=True)
#         # Replace last layer
#         self.swin_transformer.head = nn.Linear(self.swin_transformer.head.in_features, num_classes)
#         self.layer_norm = nn.LayerNorm(normalized_shape=num_classes)

#     def forward(self, x):
#         x = self.swin_transformer(x)
#         return self.layer_norm(x)  # layer normalization
# model = CustomSwinTransformer(num_classes=len(train_dataset.classes))
# model.to(device)

## model

In [16]:
model = SwinForImageClassification.from_pretrained(
    'microsoft/swin-base-patch4-window7-224',
    num_labels=5,
    ignore_mismatched_sizes=True,

)
model.to(device)

Some weights of SwinForImageClassification were not initialized from the model checkpoint at microsoft/swin-base-patch4-window7-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([5]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 1024]) in the checkpoint and torch.Size([5, 1024]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


SwinForImageClassification(
  (swin): SwinModel(
    (embeddings): SwinEmbeddings(
      (patch_embeddings): SwinPatchEmbeddings(
        (projection): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
      )
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): SwinEncoder(
      (layers): ModuleList(
        (0): SwinStage(
          (blocks): ModuleList(
            (0-1): 2 x SwinLayer(
              (layernorm_before): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
              (attention): SwinAttention(
                (self): SwinSelfAttention(
                  (query): Linear(in_features=128, out_features=128, bias=True)
                  (key): Linear(in_features=128, out_features=128, bias=True)
                  (value): Linear(in_features=128, out_features=128, bias=True)
                  (dropout): Dropout(p=0.0, inplace=False)
                )
                (output): SwinSelfO

## loss and optimizer functions

In [17]:
#class_counts = [414,478,856,1034,426] # aug-mix
class_counts = [282,320,575,700,298]
total_samples = sum(class_counts)
class_weights = [total_samples / (len(class_counts) * count) for count in class_counts]
class_weights = torch.FloatTensor(class_weights).to(device)

criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

## Model Training

In [18]:
train_loss_history = []
val_loss_history = []
train_acc_history = []
val_acc_history = []

In [None]:
# early stopping init
best_val_loss = float('inf')
patience = 5  # Number of epochs to wait
patience_counter = 0

# Training loop
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    # Training phase
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
    
        optimizer.zero_grad()
        outputs = model(images).logits
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    train_loss = running_loss / len(train_loader)
    train_acc = 100 * correct / total
    train_loss_history.append(train_loss)
    train_acc_history.append(train_acc)

    # Validation phase
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images).logits
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            
            _, predicted = torch.max(outputs, 1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    val_loss /= len(val_loader)
    val_acc = 100 * val_correct / val_total
    val_loss_history.append(val_loss)
    val_acc_history.append(val_acc)

    # Print metrics for this epoch
    print(f'Epoch {epoch + 1}/{num_epochs}, '
          f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, '
          f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')

    # early stopping cehck
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0  
        
    else:
        patience_counter += 1
    
    if patience_counter >= patience:
        print(f'Early stopping triggered after {epoch + 1} epochs.')
        break

    cm = confusion_matrix(all_labels, all_preds)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=dataset.classes, yticklabels=dataset.classes)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title(f'Confusion Matrix (Epoch {epoch + 1})')
    plt.show()

## ploting loss and acc graph

In [None]:
# Plot accuracy and loss graphs
plt.figure()
plt.plot(train_acc_history, label='Train Accuracy')
plt.plot(val_acc_history, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.title('Training and Validation Accuracy')
plt.legend()
# plt.savefig(os.path.join(run_folder, 'accuracy_lr7.png'))
# plt.close()

plt.figure()
plt.plot(train_loss_history, label='Train Loss')
plt.plot(val_loss_history, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
#plt.savefig(os.path.join(run_folder, 'loss_lr7.png'))
# plt.close()

plt.show()

## classifcation report

In [None]:
# classification report
print("Final Classification Report:")
report = classification_report(all_labels, all_preds, target_names=dataset.classes)
print(report)

In [28]:
report_path = os.path.join(run_folder, 'cls_report_val5.txt')
with open(report_path, 'w') as f:
    f.write(report)

print(f"Classification report saved to {report_path}.")

Classification report saved to results/run4\cls_report_val5.txt.


In [48]:
cm = confusion_matrix(all_labels, all_preds)

#print("Confusion Matrix:\n", cm)
if cm.sum() > 0:
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", 
                 xticklabels=dataset.classes, 
                 yticklabels=dataset.classes)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title(f'Confusion Matrix (Epoch {epoch + 1})')
    plt.savefig(os.path.join('cls_report_val5.png'), bbox_inches='tight')
    plt.close() 

Confusion Matrix:
 [[11  1  1  0  0]
 [ 1 13  3  1  1]
 [ 1  8 14  1  2]
 [ 0  2  6 29  0]
 [ 1  1  1  1 10]]


In [49]:
# model weights
model_save_path = os.path.join(run_folder, 'swin_run-val5.pth')
torch.save(model.state_dict(), model_save_path)
print(f'Model weights saved to {model_save_path}')

Model weights saved to results/run4\swin_run-val5.pth


## test phase

In [50]:
import os
import pandas as pd
from PIL import Image
from torchvision import transforms
import torch
import torchvision.models as models

In [51]:
model = SwinForImageClassification.from_pretrained(
    'microsoft/swin-tiny-patch4-window7-224',
    num_labels=5,
    ignore_mismatched_sizes=True
)

# #Load trained model
#weights_path = run_folder + '/swin_run-0.66.pth'
#model.load_state_dict(torch.load(weights_path))

model.to(device).eval()

Some weights of SwinForImageClassification were not initialized from the model checkpoint at microsoft/swin-tiny-patch4-window7-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([5]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([5, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


SwinForImageClassification(
  (swin): SwinModel(
    (embeddings): SwinEmbeddings(
      (patch_embeddings): SwinPatchEmbeddings(
        (projection): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
      )
      (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): SwinEncoder(
      (layers): ModuleList(
        (0): SwinStage(
          (blocks): ModuleList(
            (0-1): 2 x SwinLayer(
              (layernorm_before): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
              (attention): SwinAttention(
                (self): SwinSelfAttention(
                  (query): Linear(in_features=96, out_features=96, bias=True)
                  (key): Linear(in_features=96, out_features=96, bias=True)
                  (value): Linear(in_features=96, out_features=96, bias=True)
                  (dropout): Dropout(p=0.0, inplace=False)
                )
                (output): SwinSelfOutput(
  

In [52]:
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

#test images
test_dir = './../../dataset/test'
test_images = [f for f in os.listdir(test_dir) if f.endswith('.jpg')]

In [53]:
test_images.sort(key=lambda x: int(os.path.splitext(x)[0]))  # Assuming filenames are numeric

filenames = []
predictions = []

for image_name in test_images:
    image_path = os.path.join(test_dir, image_name)
    image = Image.open(image_path).convert('RGB')  # Ensure image is in RGB mode
    image = test_transform(image)
    image = image.unsqueeze(0).to(device)  # Add batch dimension and move to device

    # Predict the class
    with torch.no_grad():
        outputs = model(image)
        logits = outputs.logits
        _, predicted = torch.max(logits, 1)
        class_id = predicted.item()

    filenames.append(os.path.splitext(image_name)[0])  # Remove .jpg from filename
    predictions.append(class_id + 1)  # Map class_id to 1-based index if needed

In [54]:
#csv
results_df = pd.DataFrame({'ID': filenames, 'Predictions': predictions})
results_df.to_csv('swin-val5.csv', index=False)