<a href="https://colab.research.google.com/github/RezuanChowdhuryRifat/SETI-Signal-Classification/blob/main/Swin_Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports


In [1]:
import os
import torch
import torchvision
from torchvision import datasets
from torchvision import transforms as T 
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, sampler, random_split
from torchvision import models

In [2]:
!pip install timm 
import timm
from timm.loss import LabelSmoothingCrossEntropy

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting timm
  Downloading timm-0.6.12-py3-none-any.whl (549 kB)
[K     |████████████████████████████████| 549 kB 5.1 MB/s 
Collecting huggingface-hub
  Downloading huggingface_hub-0.11.1-py3-none-any.whl (182 kB)
[K     |████████████████████████████████| 182 kB 54.2 MB/s 
Installing collected packages: huggingface-hub, timm
Successfully installed huggingface-hub-0.11.1 timm-0.6.12


# Dataset

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
%cd /content/drive/MyDrive/dataset

/content/drive/MyDrive/dataset


In [5]:
def get_classes(data_dir):
    all_data = datasets.ImageFolder(data_dir)
    return all_data.classes

In [6]:
train_path = "/content/drive/MyDrive/dataset/train"
test_path = "/content/drive/MyDrive/dataset/test"
valid_path = "/content/drive/MyDrive/dataset/valid"


SIZE = 224  #Resize images
batch_size = 32

train_transform = T.Compose([
            T.RandomHorizontalFlip(),
            T.RandomVerticalFlip(),
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(timm.data.IMAGENET_DEFAULT_MEAN, timm.data.IMAGENET_DEFAULT_STD)
        ])

test_transform = T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(timm.data.IMAGENET_DEFAULT_MEAN, timm.data.IMAGENET_DEFAULT_STD), 
        ])

train_data = datasets.ImageFolder(os.path.join(train_path), transform = train_transform) 
valid_data = datasets.ImageFolder(os.path.join(valid_path), transform = test_transform) 
test_data = datasets.ImageFolder(os.path.join(test_path), transform = test_transform) 

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=1)
valid_loader = DataLoader(valid_data, batch_size=batch_size, shuffle=False, num_workers=1)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=1)


In [7]:
print(len(train_loader), len(valid_loader), len(test_loader))

175 22 22


In [8]:
classes = get_classes(train_path)
print(classes, len(classes))

['brightpixel', 'narrowband', 'narrowbanddrd', 'noise', 'squarepulsednarrowband', 'squiggle', 'squigglesquarepulsednarrowband'] 7


In [9]:
dataloaders = {
    "train": train_loader,
    "val": valid_loader
}
dataset_sizes = {
    "train": len(train_data),
    "val": len(valid_data)
}

In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

# Model

**Download pretrained model**

In [11]:
HUB_URL = "SharanSMenon/swin-transformer-hub:main"
MODEL_NAME = "swin_tiny_patch4_window7_224"
model = torch.hub.load(HUB_URL, MODEL_NAME, pretrained=True)

  "You are about to download and run code from an untrusted repository. In a future release, this won't "
Downloading: "https://github.com/SharanSMenon/swin-transformer-hub/zipball/main" to /root/.cache/torch/hub/main.zip
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Downloading: "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth" to /root/.cache/torch/hub/checkpoints/swin_tiny_patch4_window7_224.pth


  0%|          | 0.00/109M [00:00<?, ?B/s]

In [12]:
for param in model.parameters(): #freeze model
    param.requires_grad = False

In [13]:
from torchsummary import summary
summary(model.cuda(), (3, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 96, 56, 56]           4,704
         LayerNorm-2             [-1, 3136, 96]             192
        PatchEmbed-3             [-1, 3136, 96]               0
           Dropout-4             [-1, 3136, 96]               0
         LayerNorm-5             [-1, 3136, 96]             192
            Linear-6              [-1, 49, 288]          27,936
           Softmax-7            [-1, 3, 49, 49]               0
           Dropout-8            [-1, 3, 49, 49]               0
            Linear-9               [-1, 49, 96]           9,312
          Dropout-10               [-1, 49, 96]               0
  WindowAttention-11               [-1, 49, 96]               0
         Identity-12             [-1, 3136, 96]               0
        LayerNorm-13             [-1, 3136, 96]             192
           Linear-14            [-1, 31

In [14]:
model

SwinTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
    (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (layers): ModuleList(
    (0): BasicLayer(
      dim=96, input_resolution=(56, 56), depth=2
      (blocks): ModuleList(
        (0): SwinTransformerBlock(
          dim=96, input_resolution=(56, 56), num_heads=3, window_size=7, shift_size=0, mlp_ratio=4.0
          (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            dim=96, window_size=(7, 7), num_heads=3
            (qkv): Linear(in_features=96, out_features=288, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=96, out_features=96, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): Identity()
          (norm2): LayerNo

In [15]:
model.layers[3]

BasicLayer(
  dim=768, input_resolution=(7, 7), depth=2
  (blocks): ModuleList(
    (0): SwinTransformerBlock(
      dim=768, input_resolution=(7, 7), num_heads=24, window_size=7, shift_size=0, mlp_ratio=4.0
      (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): WindowAttention(
        dim=768, window_size=(7, 7), num_heads=24
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
        (softmax): Softmax(dim=-1)
      )
      (drop_path): DropPath(drop_prob=0.091)
      (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate=none)
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
     

In [16]:
for param in model.layers[3].parameters(): #unfreeze model
    param.requires_grad = True

summary(model.cuda(), (3, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 96, 56, 56]           4,704
         LayerNorm-2             [-1, 3136, 96]             192
        PatchEmbed-3             [-1, 3136, 96]               0
           Dropout-4             [-1, 3136, 96]               0
         LayerNorm-5             [-1, 3136, 96]             192
            Linear-6              [-1, 49, 288]          27,936
           Softmax-7            [-1, 3, 49, 49]               0
           Dropout-8            [-1, 3, 49, 49]               0
            Linear-9               [-1, 49, 96]           9,312
          Dropout-10               [-1, 49, 96]               0
  WindowAttention-11               [-1, 49, 96]               0
         Identity-12             [-1, 3136, 96]               0
        LayerNorm-13             [-1, 3136, 96]             192
           Linear-14            [-1, 31

**Add classifier**

In [17]:


n_inputs = model.head.in_features
model.head = nn.Sequential(
    nn.Linear(n_inputs, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, 7)
)
model = model.to(device)
print(model.head)

Sequential(
  (0): Linear(in_features=768, out_features=512, bias=True)
  (1): ReLU()
  (2): Dropout(p=0.3, inplace=False)
  (3): Linear(in_features=512, out_features=7, bias=True)
)


In [18]:
criterion = LabelSmoothingCrossEntropy()
criterion = criterion.to(device)
optimizer = optim.AdamW(model.head.parameters(), lr=0.001)
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.97)

# Training

In [19]:
import sys
from tqdm import tqdm
import time
import copy

def train_model(model, criterion, optimizer, scheduler, num_epochs=10):
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print("-"*10)
        
        for phase in ['train', 'val']: # We do training and validation phase per epoch
            if phase == 'train':
                model.train() # model to training mode
            else:
                model.eval() # model to evaluate
            
            running_loss = 0.0
            running_corrects = 0.0
            
            for inputs, labels in tqdm(dataloaders[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                optimizer.zero_grad()
                
                with torch.set_grad_enabled(phase == 'train'): # no autograd makes validation go faster
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1) # used for accuracy
                    loss = criterion(outputs, labels)
                    
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
                
            if phase == 'train':
                scheduler.step() # step at end of epoch
            
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc =  running_corrects.double() / dataset_sizes[phase]
            
            print("{} Loss: {:.4f} Acc: {:.4f}".format(phase, epoch_loss, epoch_acc))
            
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict()) # keep the best validation accuracy model
        print()
    time_elapsed = time.time() - since # slight error
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print("Best Val Acc: {:.4f}".format(best_acc))
    
    model.load_state_dict(best_model_wts)
    return model

In [21]:
model_ft = train_model(model, criterion, optimizer, exp_lr_scheduler, num_epochs=20)

Epoch 0/19
----------


100%|██████████| 175/175 [15:03<00:00,  5.16s/it]


train Loss: 1.1277 Acc: 0.7021


100%|██████████| 22/22 [02:19<00:00,  6.33s/it]


val Loss: 1.0019 Acc: 0.7657

Epoch 1/19
----------


100%|██████████| 175/175 [01:23<00:00,  2.10it/s]


train Loss: 1.0574 Acc: 0.7370


100%|██████████| 22/22 [00:09<00:00,  2.26it/s]


val Loss: 1.0312 Acc: 0.7443

Epoch 2/19
----------


100%|██████████| 175/175 [01:22<00:00,  2.13it/s]


train Loss: 1.0262 Acc: 0.7486


100%|██████████| 22/22 [00:09<00:00,  2.27it/s]


val Loss: 1.0262 Acc: 0.7486

Epoch 3/19
----------


100%|██████████| 175/175 [01:20<00:00,  2.19it/s]


train Loss: 1.0104 Acc: 0.7636


100%|██████████| 22/22 [00:09<00:00,  2.31it/s]


val Loss: 0.9868 Acc: 0.7729

Epoch 4/19
----------


100%|██████████| 175/175 [01:19<00:00,  2.21it/s]


train Loss: 0.9973 Acc: 0.7650


100%|██████████| 22/22 [00:09<00:00,  2.32it/s]


val Loss: 0.9427 Acc: 0.8000

Epoch 5/19
----------


100%|██████████| 175/175 [01:18<00:00,  2.22it/s]


train Loss: 0.9928 Acc: 0.7696


100%|██████████| 22/22 [00:09<00:00,  2.32it/s]


val Loss: 0.9317 Acc: 0.7943

Epoch 6/19
----------


100%|██████████| 175/175 [01:19<00:00,  2.21it/s]


train Loss: 0.9891 Acc: 0.7729


100%|██████████| 22/22 [00:09<00:00,  2.33it/s]


val Loss: 0.9213 Acc: 0.7943

Epoch 7/19
----------


100%|██████████| 175/175 [01:18<00:00,  2.22it/s]


train Loss: 0.9723 Acc: 0.7762


100%|██████████| 22/22 [00:10<00:00,  2.09it/s]


val Loss: 0.9258 Acc: 0.7943

Epoch 8/19
----------


100%|██████████| 175/175 [01:18<00:00,  2.22it/s]


train Loss: 0.9785 Acc: 0.7788


100%|██████████| 22/22 [00:09<00:00,  2.30it/s]


val Loss: 0.9214 Acc: 0.8014

Epoch 9/19
----------


100%|██████████| 175/175 [01:19<00:00,  2.21it/s]


train Loss: 0.9689 Acc: 0.7780


100%|██████████| 22/22 [00:09<00:00,  2.31it/s]


val Loss: 0.9578 Acc: 0.7871

Epoch 10/19
----------


100%|██████████| 175/175 [01:18<00:00,  2.22it/s]


train Loss: 0.9732 Acc: 0.7784


100%|██████████| 22/22 [00:09<00:00,  2.33it/s]


val Loss: 0.9285 Acc: 0.7971

Epoch 11/19
----------


100%|██████████| 175/175 [01:19<00:00,  2.20it/s]


train Loss: 0.9623 Acc: 0.7811


100%|██████████| 22/22 [00:09<00:00,  2.29it/s]


val Loss: 0.9069 Acc: 0.8071

Epoch 12/19
----------


100%|██████████| 175/175 [01:19<00:00,  2.21it/s]


train Loss: 0.9481 Acc: 0.7864


100%|██████████| 22/22 [00:09<00:00,  2.29it/s]


val Loss: 0.9131 Acc: 0.7900

Epoch 13/19
----------


100%|██████████| 175/175 [01:20<00:00,  2.19it/s]


train Loss: 0.9519 Acc: 0.7852


100%|██████████| 22/22 [00:09<00:00,  2.29it/s]


val Loss: 0.9165 Acc: 0.8000

Epoch 14/19
----------


100%|██████████| 175/175 [01:19<00:00,  2.21it/s]


train Loss: 0.9565 Acc: 0.7848


100%|██████████| 22/22 [00:09<00:00,  2.32it/s]


val Loss: 0.9392 Acc: 0.8029

Epoch 15/19
----------


100%|██████████| 175/175 [01:19<00:00,  2.21it/s]


train Loss: 0.9525 Acc: 0.7864


100%|██████████| 22/22 [00:10<00:00,  2.10it/s]


val Loss: 0.9260 Acc: 0.7957

Epoch 16/19
----------


100%|██████████| 175/175 [01:19<00:00,  2.20it/s]


train Loss: 0.9515 Acc: 0.7845


100%|██████████| 22/22 [00:10<00:00,  2.18it/s]


val Loss: 0.9095 Acc: 0.7986

Epoch 17/19
----------


100%|██████████| 175/175 [01:19<00:00,  2.20it/s]


train Loss: 0.9364 Acc: 0.7955


100%|██████████| 22/22 [00:09<00:00,  2.30it/s]


val Loss: 0.9097 Acc: 0.8086

Epoch 18/19
----------


100%|██████████| 175/175 [01:19<00:00,  2.20it/s]


train Loss: 0.9464 Acc: 0.7902


100%|██████████| 22/22 [00:09<00:00,  2.30it/s]


val Loss: 0.9348 Acc: 0.7914

Epoch 19/19
----------


100%|██████████| 175/175 [01:19<00:00,  2.20it/s]


train Loss: 0.9449 Acc: 0.7968


100%|██████████| 22/22 [00:09<00:00,  2.29it/s]

val Loss: 0.8954 Acc: 0.8043

Training complete in 45m 40s
Best Val Acc: 0.8086





**Save model**

In [None]:
%cd "/content/drive/MyDrive/dataset/SwinTransformerModel"

torch.save(model_ft.state_dict(), 'model_weights.pth')

/content/drive/MyDrive/dataset/SwinTransformerModel


# Testing

In [22]:
import numpy as np

test_loss = 0.0
class_correct = list(0 for i in range(len(classes)))
class_total = list(0 for i in range(len(classes)))
model_ft.eval()

for data, target in tqdm(test_loader):
    data, target = data.to(device), target.to(device)
    with torch.no_grad(): # turn off autograd for faster testing
        output = model_ft(data)
        loss = criterion(output, target)
    test_loss = loss.item() * data.size(0)
    _, pred = torch.max(output, 1)
    correct_tensor = pred.eq(target.data.view_as(pred))
    correct = np.squeeze(correct_tensor.cpu().numpy())
    if len(target) == 32:
        for i in range(32):
            label = target.data[i]
            class_correct[label] += correct[i].item()
            class_total[label] += 1

test_loss = test_loss / len(test_data)
print('Test Loss: {:.4f}'.format(test_loss))
for i in range(len(classes)):
    if class_total[i] > 0:
        print("Test Accuracy of %5s: %2d%% (%2d/%2d)" % (
            classes[i], 100*class_correct[i]/class_total[i], np.sum(class_correct[i]), np.sum(class_total[i])
        ))
    else:
        print("Test accuracy of %5s: NA" % (classes[i]))
print("Test Accuracy of %2d%% (%2d/%2d)" % (
            100*np.sum(class_correct)/np.sum(class_total), np.sum(class_correct), np.sum(class_total)
        ))

100%|██████████| 22/22 [02:25<00:00,  6.61s/it]

Test Loss: 0.0382
Test Accuracy of brightpixel: 78% (78/100)
Test Accuracy of narrowband: 93% (93/100)
Test Accuracy of narrowbanddrd: 65% (65/100)
Test Accuracy of noise: 99% (99/100)
Test Accuracy of squarepulsednarrowband: 79% (79/100)
Test Accuracy of squiggle: 78% (78/100)
Test Accuracy of squigglesquarepulsednarrowband: 75% (54/72)
Test Accuracy of 81% (546/672)



