<a href="https://colab.research.google.com/github/arun-arunisto/OpenCVTutorialAbel/blob/main/UsingPretrainedSwinTransformerModelForImageClassification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [2]:
#transformation pipeline
transform = transforms.Compose({
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
})

In [3]:
#loading the dataset
train_dataset = datasets.ImageFolder(root='/content/drive/MyDrive/data/brain_tumor_dataset/train', transform=transform)
test_dataset = datasets.ImageFolder(root='/content/drive/MyDrive/data/brain_tumor_dataset/test', transform=transform)

In [4]:
#creating dataloader
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [5]:
#using pretrained swin transformer
from transformers import SwinForImageClassification

#loading pretrained swin transformer model
model = SwinForImageClassification.from_pretrained(
    'microsoft/swin-tiny-patch4-window7-224',
    num_labels=2,
    ignore_mismatched_sizes=True,
)

config.json:   0%|          | 0.00/71.8k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/113M [00:00<?, ?B/s]

Some weights of SwinForImageClassification were not initialized from the model checkpoint at microsoft/swin-tiny-patch4-window7-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([2]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
#freezing all layers except the final classification head
for param in model.parameters():
    param.requires_grad = False

In [None]:
#unfreezing the classification layer
for param in model.classifier.parameters():
    param.requires_grad = True

In [None]:
import torch.optim as optim
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR

In [None]:
#setting up the optimizer and loss function
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

In [10]:
#gpu -> neural will be more faster than cpus
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

SwinForImageClassification(
  (swin): SwinModel(
    (embeddings): SwinEmbeddings(
      (patch_embeddings): SwinPatchEmbeddings(
        (projection): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
      )
      (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): SwinEncoder(
      (layers): ModuleList(
        (0): SwinStage(
          (blocks): ModuleList(
            (0-1): 2 x SwinLayer(
              (layernorm_before): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
              (attention): SwinAttention(
                (self): SwinSelfAttention(
                  (query): Linear(in_features=96, out_features=96, bias=True)
                  (key): Linear(in_features=96, out_features=96, bias=True)
                  (value): Linear(in_features=96, out_features=96, bias=True)
                  (dropout): Dropout(p=0.0, inplace=False)
                )
                (output): SwinSelfOutput(
  

In [None]:
#training loop
#10
for epoch in range(10):
  model.train()
  running_loss = 0.0
  correct = 0
  total = 0
  for inputs, labels in train_loader:
    inputs, labels = inputs.to(device), labels.to(device)
    #zero the parameter gradients
    optimizer.zero_grad()
    #forward pass
    outputs = model(inputs).logits
    #calculate loss
    loss = criterion(outputs, labels)
    #backward pass
    loss.backward()
    #update weights
    optimizer.step()
    #calculating the accuracy
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()
    running_loss += loss.item()
  #printing the training results
  print(f"Epoch [{epoch+1}/10], Loss: {running_loss/len(train_loader):.4f}, Accuracy: {100*correct/total:.2f}%")

  #validation
  model.eval()
  val_correct = 0
  val_total = 0
  with torch.no_grad():
    for val_inputs, val_labels in test_loader:
      val_inputs, val_labels = val_inputs.to(device), val_labels.to(device)
      val_outputs = model(val_inputs).logits
      _, val_predicted = torch.max(val_outputs, 1)
      val_total += val_labels.size(0)
      val_correct += (val_predicted == val_labels).sum().item()
  print(f"Validation Accuracy: {100*val_correct/val_total:.2f}%")

Epoch [1/10], Loss: 0.6856, Accuracy: 55.89%
Validation Accuracy: 55.41%
Epoch [2/10], Loss: 0.6462, Accuracy: 67.68%
Validation Accuracy: 70.27%
Epoch [3/10], Loss: 0.6137, Accuracy: 73.18%
Validation Accuracy: 68.24%
Epoch [4/10], Loss: 0.5868, Accuracy: 75.53%
Validation Accuracy: 76.35%
Epoch [5/10], Loss: 0.5571, Accuracy: 79.35%
Validation Accuracy: 77.03%
Epoch [6/10], Loss: 0.5396, Accuracy: 79.46%
Validation Accuracy: 77.70%
Epoch [7/10], Loss: 0.5156, Accuracy: 80.25%
Validation Accuracy: 77.70%
Epoch [8/10], Loss: 0.5018, Accuracy: 80.70%
Validation Accuracy: 77.70%
Epoch [9/10], Loss: 0.4971, Accuracy: 80.58%
Validation Accuracy: 77.70%
Epoch [10/10], Loss: 0.4734, Accuracy: 83.28%
Validation Accuracy: 76.35%


In [None]:
#saving the model
torch.save(model.state_dict(), "/content/drive/MyDrive/AbelFolder/swintransformer_model.pth")

## Calculating accuracy of the Swin Transformer Model

In [12]:
from PIL import Image

In [13]:
#loading the saved model
model.load_state_dict(torch.load("/content/drive/MyDrive/AbelFolder/swintransformer_model.pth"))

  model.load_state_dict(torch.load("/content/drive/MyDrive/AbelFolder/swintransformer_model.pth"))


<All keys matched successfully>

In [14]:
model.eval()

SwinForImageClassification(
  (swin): SwinModel(
    (embeddings): SwinEmbeddings(
      (patch_embeddings): SwinPatchEmbeddings(
        (projection): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
      )
      (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): SwinEncoder(
      (layers): ModuleList(
        (0): SwinStage(
          (blocks): ModuleList(
            (0-1): 2 x SwinLayer(
              (layernorm_before): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
              (attention): SwinAttention(
                (self): SwinSelfAttention(
                  (query): Linear(in_features=96, out_features=96, bias=True)
                  (key): Linear(in_features=96, out_features=96, bias=True)
                  (value): Linear(in_features=96, out_features=96, bias=True)
                  (dropout): Dropout(p=0.0, inplace=False)
                )
                (output): SwinSelfOutput(
  

In [15]:
#opening the image
img = Image.open("/content/drive/MyDrive/data/brain_tumor_dataset/train/healthy/0000.jpg")
img = transform(img).unsqueeze(0).to(device)

In [17]:
#predicting healthy
output = model(img).logits
_, predicted = torch.max(output, 1)
print(f"Predicted class: {predicted.item()}")

Predicted class: 0


In [18]:
#predicting the tumor
img = Image.open("/content/drive/MyDrive/data/brain_tumor_dataset/train/tumor/00004.jpg")
img = transform(img).unsqueeze(0).to(device)

In [19]:
#predicting healthy
output = model(img).logits
_, predicted = torch.max(output, 1)
print(f"Predicted class: {predicted.item()}")

Predicted class: 1


In [20]:
#healthy folder
path = "/content/drive/MyDrive/data/brain_tumor_dataset/test/healthy"

In [21]:
import os

In [22]:
files = os.listdir(path)

In [24]:
for f in files:
  try:
    img = Image.open(os.path.join(path, f))
    img = transform(img).unsqueeze(0).to(device)
    output = model(img).logits
    _, predicted = torch.max(output, 1)
    print(f"Predicted Class: {predicted.item()} | filename: {f} | Actual Class: 0")
  except Exception as e:
    print(e)

Predicted Class: 0 | filename: 0796.jpg | Actual Class: 0
Predicted Class: 0 | filename: 0676.jpg | Actual Class: 0
Predicted Class: 1 | filename: 0698.jpg | Actual Class: 0
Predicted Class: 1 | filename: 0601.jpg | Actual Class: 0
Predicted Class: 0 | filename: 0861.jpg | Actual Class: 0
Predicted Class: 1 | filename: 0615.jpg | Actual Class: 0
Predicted Class: 0 | filename: 0874.jpg | Actual Class: 0
Predicted Class: 0 | filename: 0820.jpg | Actual Class: 0
Predicted Class: 1 | filename: 0785.jpg | Actual Class: 0
Predicted Class: 0 | filename: 0792.jpg | Actual Class: 0
Predicted Class: 0 | filename: 0731.jpg | Actual Class: 0
Predicted Class: 0 | filename: 0762.jpg | Actual Class: 0
Predicted Class: 1 | filename: 0710.jpg | Actual Class: 0
Predicted Class: 0 | filename: 0858.jpg | Actual Class: 0
Predicted Class: 0 | filename: 0691.jpg | Actual Class: 0
Predicted Class: 0 | filename: 0791.jpg | Actual Class: 0
Predicted Class: 1 | filename: 0639.jpg | Actual Class: 0
Predicted Clas

In [27]:
#function for calculating accuracy
def calculate_accuracy(model, img_path, actual_class):
  files = os.listdir(img_path)
  total_images = len(files)
  predicted_ones = 0
  for f in files:
    try:
      img = Image.open(os.path.join(path, f))
      img = transform(img).unsqueeze(0).to(device)
      output = model(img).logits
      _, predicted = torch.max(output, 1)
      if int(predicted.item()) == int(actual_class):
        predicted_ones += 1
    except Exception as e:
      continue
  accuracy_score = (predicted_ones/total_images)*100
  return accuracy_score

In [28]:
img_path = "/content/drive/MyDrive/data/brain_tumor_dataset/test/healthy"
actual_class = 0
print("Accuracy Score:",calculate_accuracy(model, img_path, actual_class))

Accuracy Score: 52.054794520547944


In [31]:
img_path = "/content/drive/MyDrive/data/brain_tumor_dataset/test/tumor"
actual_class = 1
print("Accuracy Score:",calculate_accuracy(model, img_path, actual_class))

Accuracy Score: 2.631578947368421


Accuracy Score for healthy: **52.04%**

Accuracy Score for tumor: **2.63%**