In [2]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

In [3]:
# Mini-batch size used by the DataLoaders built later in this notebook.
BATCH_SIZE = 64

# Root folders handed to datasets.ImageFolder for the training and test images.
train_dir = "TRAIN_PATH"
test_dir = "TEST_PATH"

In [6]:
# Training pipeline: single-channel 32x32 images with light geometric
# augmentation (small rotation, small translation) before tensor conversion.
train_transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((32, 32)),
        transforms.RandomRotation(10),
        transforms.RandomAffine(degrees=0, translate=(0.08, 0.08)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# Evaluation pipeline: same preprocessing as training, but with no random
# augmentation so validation/test results are deterministic.
val_transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

In [None]:
val_ratio = 0.15

# Two views over the same image folder, one per transform pipeline.
# Both subsets returned by random_split would share ONE underlying dataset, so
# assigning `.dataset.transform` on one subset also changes it for the other
# (the last assignment wins and training would lose its augmentation).
# ImageFolder enumerates classes and files in sorted order, so sample indices
# are interchangeable between the two views.
full_train_ds = datasets.ImageFolder(train_dir, transform=train_transform)
full_val_ds = datasets.ImageFolder(train_dir, transform=val_transform)

total_size = len(full_train_ds)
val_size = int(val_ratio * total_size)
train_size = total_size - val_size

# Fixed seed so the train/validation partition is reproducible.
generator = torch.Generator().manual_seed(42)
train_ds, val_split = torch.utils.data.random_split(
    full_train_ds, [train_size, val_size], generator=generator
)

# Same held-out indices, but read through the augmentation-free view.
val_ds = torch.utils.data.Subset(full_val_ds, val_split.indices)

# Held-out test data, evaluated without augmentation.
test_ds = datasets.ImageFolder(test_dir, transform=val_transform)

train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)


In [None]:
import json

# Persist the class-name -> label-index mapping of the training dataset as JSON.
with open("class_to_idx.json", "w") as mapping_file:
    mapping_file.write(json.dumps(train_ds.dataset.class_to_idx))

In [None]:
# Class names in label-index order, as discovered by ImageFolder.
class_names = full_train_ds.classes

# A bare expression in the middle of a cell is never displayed, so print the
# training-set size explicitly alongside the number of classes.
print(f"Training samples: {len(train_ds)}")
print(len(class_names))

In [None]:
def imshow(inp, title=None, mean=None, std=None):
    """Display a (C, H, W) image tensor with matplotlib.

    Args:
        inp: image tensor laid out as (channels, height, width).
        title: optional title drawn above the image.
        mean, std: if both are given, the normalisation ``(x - mean) / std`` is
            undone before display. Without this, normalised pixels below 0 are
            clipped away and the image looks far too dark.
    """
    # (C, H, W) -> (H, W, C), which is the layout matplotlib expects.
    img = inp.detach().cpu().numpy().transpose((1, 2, 0))
    if mean is not None and std is not None:
        img = np.asarray(std) * img + np.asarray(mean)
    img = np.clip(img, 0, 1)
    plt.imshow(img)
    if title:
        plt.title(title)
    plt.pause(0.001)


# Preview the first four images of one training batch with their class names.
inputs, classes = next(iter(train_loader))

out = torchvision.utils.make_grid(inputs[:4])
plt.figure(figsize=(10, 5))
imshow(
    out,
    title=", ".join(class_names[x] for x in classes[:4]),
    mean=0.5,  # matches transforms.Normalize(mean=[0.5], std=[0.5])
    std=0.5,
)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    """Residual unit: two 3x3 conv + batch-norm layers plus a skip connection.

    The skip path is an identity unless the block changes the spatial size
    (``stride != 1``) or the channel count, in which case a 1x1 conv + batch-norm
    projects the input to the shape of the main path.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the block.
        stride: stride of the first 3x3 conv (and of the projection, if any).
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        # Main path: the first conv may downsample, the second keeps the size.
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Skip path: project only when input and output shapes differ.
        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        skip = self.shortcut(x)
        hidden = F.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(hidden))
        return F.relu(residual + skip)


class SpatialSelfAttention(nn.Module):
    """Self-attention over the spatial positions of a feature map.

    Every position attends to every other position. The attended features are
    scaled by a learnable scalar ``gamma`` (initialised to zero, so the module
    starts out as an identity mapping) and added back to the input.

    Args:
        in_channels: channels of the input (and output) feature map.
        attn_channels: channels of the query/key projections.
    """

    def __init__(self, in_channels,attn_channels=64):
        super().__init__()

        # 1x1 convolutions act as per-position linear projections.
        self.query_conv = nn.Conv2d(in_channels, attn_channels, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, attn_channels, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)

        # Learnable weight of the attention branch in the residual sum.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self,x):
        """
        x:      [B, C, H, W] input feature map
        return: [B, C, H, W] input plus gamma-scaled attention output
        """
        batch, channels, height, width = x.shape

        # Project, then flatten the H x W grid into N = H*W positions.
        queries = self.query_conv(x).flatten(2).transpose(1, 2)  # [B, N, attn_channels]
        keys = self.key_conv(x).flatten(2)                       # [B, attn_channels, N]
        values = self.value_conv(x).flatten(2).transpose(1, 2)   # [B, N, C]

        # Row i holds the attention weights of position i over all positions.
        weights = torch.softmax(torch.bmm(queries, keys), dim=-1)  # [B, N, N]

        # Weighted sum of the values, folded back onto the spatial grid.
        attended = torch.bmm(weights, values)                      # [B, N, C]
        attended = attended.transpose(1, 2).reshape(batch, channels, height, width)

        # Residual fusion.
        return self.gamma * attended + x



class Basnet(nn.Module):
    """Residual CNN classifier for 1-channel images with one spatial attention layer.

    Layout: stem conv -> three residual stages -> spatial self-attention ->
    one more residual stage -> global average pooling -> MLP classifier.
    The shape comments below assume a 32x32 input.

    Args:
        num_classes: size of the output logit vector.
    """

    def __init__(self, num_classes=46):
        super().__init__()

        # Attention applied to the 256-channel output of stage 3.
        self.attention = SpatialSelfAttention(in_channels=256)

        # Stem: lifts the single input channel to 64 feature maps, size unchanged.
        self.stem = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )

        # Residual stages; each stride-2 block halves the spatial size.
        self.stage1 = nn.Sequential(
            ResidualBlock(64, 64),
            ResidualBlock(64, 64),
        )  # 64 x 32 x 32
        self.stage2 = nn.Sequential(
            ResidualBlock(64, 128, stride=2),
            ResidualBlock(128, 128),
        )  # 128 x 16 x 16
        self.stage3 = nn.Sequential(
            ResidualBlock(128, 256, stride=2),
            ResidualBlock(256, 256),
        )  # 256 x 8 x 8
        self.stage4 = nn.Sequential(
            ResidualBlock(256, 512, stride=2),
        )  # 512 x 4 x 4

        self.gap = nn.AdaptiveAvgPool2d(1)

        # Classifier head on the pooled 512-dim feature vector.
        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Map a batch of images [B, 1, H, W] to class logits [B, num_classes]."""
        features = self.stem(x)
        for stage in (self.stage1, self.stage2, self.stage3):
            features = stage(features)

        features = self.attention(features)  # attention on the stage-3 feature map
        features = self.stage4(features)

        pooled = self.gap(features)          # [B, 512, 1, 1]
        flat = torch.flatten(pooled, 1)      # [B, 512]
        return self.classifier(flat)

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Network, loss and optimisation setup.
model = Basnet()
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)

# Multiply the learning rate by 0.1 every 7 scheduler steps.
scheduler = StepLR(optimizer, step_size=7, gamma=0.1)
print(device)


In [None]:
# Count only the parameters that will receive gradient updates.
total_params = sum(
    weight.numel()
    for weight in model.parameters()
    if weight.requires_grad
)
print(f"Total trainable parameters: {total_params}")

In [None]:
class EarlyStopping:
    """Signal when the validation loss has stopped improving.

    Call the instance once per epoch with the current validation loss. After
    ``patience`` consecutive calls whose loss exceeds the best loss seen so far
    by more than ``delta``, ``early_stop`` becomes True.

    Args:
        patience: number of consecutive non-improving epochs tolerated.
        delta: tolerance; a loss up to ``best_score + delta`` resets the counter.

    Attributes:
        counter: current run of non-improving epochs.
        best_score: lowest validation loss seen so far (None before the first call).
        early_stop: True once training should stop.
    """

    def __init__(self, patience=7, delta=0):
        self.patience = patience
        self.delta = delta
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, val_loss):
        """Record one epoch's validation loss and update the stopping state."""
        if val_loss != val_loss:
            # NaN never counts as an improvement and must not become the baseline.
            acceptable = False
        elif self.best_score is None:
            acceptable = True
        else:
            acceptable = val_loss <= self.best_score + self.delta

        if acceptable:
            # Keep the true minimum: a loss that is merely within `delta` of the
            # best must not replace it, otherwise the baseline drifts upward and
            # a slowly rising loss would never trigger the stop.
            if self.best_score is None:
                self.best_score = val_loss
            else:
                self.best_score = min(self.best_score, val_loss)
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

In [None]:
def train_model(model, train_loader, val_loader, criterion, optimizer, epochs=35):
    """Train ``model`` with per-epoch validation, LR scheduling and early stopping.

    Each epoch runs one optimisation pass over ``train_loader`` and one
    gradient-free pass over ``val_loader``. Training stops early once the
    validation loss has not improved for 5 consecutive epochs. Before returning,
    the weights from the epoch with the lowest validation loss are loaded back
    into the model, so the returned model is the best one seen rather than the
    one from the final epoch.

    NOTE: relies on the notebook-level names ``scheduler``, ``EarlyStopping``
    and ``tqdm``.

    Args:
        model: network to train (modified in place).
        train_loader: iterable of (images, labels) training batches.
        val_loader: iterable of (images, labels) validation batches.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimiser stepping the model parameters.
        epochs: maximum number of epochs.

    Returns:
        (model, train_history) where ``train_history`` maps "train_loss",
        "val_loss" and "val_acc" to one value per completed epoch.
    """
    import copy  # local import keeps this cell runnable on its own

    # Run on whatever device the model already lives on instead of depending
    # on a notebook-level `device` variable.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    train_history = {"train_loss": [], "val_loss": [], "val_acc": []}
    stopper = EarlyStopping(patience=5)

    # Snapshot of the best weights seen so far (by validation loss).
    best_val_loss = float("inf")
    best_state = None

    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")

        print("-" * 20)

        model.train()
        running_loss = 0

        for images, labels in tqdm(train_loader):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        avg_train_loss = running_loss / len(train_loader)
        train_history["train_loss"].append(avg_train_loss)

        # Validation pass
        model.eval()
        val_loss = 0
        correct = 0
        total = 0

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)

                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()

                preds = outputs.argmax(1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)

        avg_val_loss = val_loss / len(val_loader)
        val_acc = correct / total

        train_history["val_loss"].append(avg_val_loss)
        train_history["val_acc"].append(val_acc)

        # Learning-rate schedule advances once per epoch (notebook-level scheduler).
        scheduler.step()

        print(f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val ACC: {val_acc:.4f}")

        # Deep copy: state_dict() returns references to the live tensors, which
        # later optimisation steps would otherwise overwrite.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_state = copy.deepcopy(model.state_dict())

        stopper(avg_val_loss)
        if stopper.early_stop:
            print("Early stopping")
            break

    # Early stopping fires several epochs after the best one; roll back to it.
    if best_state is not None:
        model.load_state_dict(best_state)

    return model, train_history


In [None]:
# Run training; keep the trained model and the per-epoch metric history.
model, train_history = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    epochs=35,
)

In [None]:
# Training visualisation: loss curves (left) and validation accuracy (right).
# Both panels are drawn on explicit axes of ONE figure and shown once at the
# end; calling plt.show() between the panels would split them across figures.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

# Training and validation loss
loss_ax.plot(train_history['train_loss'], label="Train Loss", color='orange')
loss_ax.plot(train_history['val_loss'], label="Validation Loss", color='green')
loss_ax.set_title('Loss over Epochs')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.legend()

# Validation accuracy
acc_ax.plot(train_history['val_acc'], label='Val Acc')
acc_ax.set_title('Accuracy over Epochs')
acc_ax.set_xlabel('Epoch')
acc_ax.set_ylabel('Accuracy')
acc_ax.legend()

fig.tight_layout()
plt.show()

In [None]:
def evaluate_model(model,test_loader, criterion):
    """Evaluate ``model`` on ``test_loader`` and report mean loss and accuracy.

    Args:
        model: network to evaluate; it is switched to eval mode.
        test_loader: sized iterable of (images, labels) batches.
        criterion: loss function called as ``criterion(outputs, labels)``.

    Returns:
        dict with two floats:
            "avg_test_loss": sum of the per-batch losses divided by the number
                of batches,
            "test_acc": fraction of samples whose arg-max prediction equals
                the label.
    """
    model.eval()

    # Run on whatever device the model already lives on instead of depending
    # on a notebook-level `device` variable.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    test_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)
            test_loss += loss.item()

            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    # Scalars, not per-epoch lists: evaluation runs exactly once.
    test_history = {
        "avg_test_loss": test_loss / len(test_loader),
        "test_acc": correct / total,
    }

    print(f"Test Loss: {test_history['avg_test_loss']:.4f} | Test ACC: {test_history['test_acc']:.4f}")

    return test_history

In [None]:
# Deterministic preprocessing for test images (same steps as val_transform).
# NOTE(review): test_ds was created earlier with val_transform; this pipeline
# is defined here but not attached to test_ds in this cell.
test_transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),  # ensure 1 channel
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# Fixed order (no shuffling) for the test pass.
test_loader = DataLoader(
    test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True
)

# Final evaluation of the trained model on the test set.
test_history = evaluate_model(model, test_loader, criterion)


In [None]:
# Test results in context. The test loss and accuracy are single numbers, so
# they are drawn as horizontal reference lines over the per-epoch curves
# (plotting a lone scalar with plt.plot produces an invisible point).
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

# Training loss per epoch vs. final test loss
loss_ax.plot(train_history['train_loss'], label="Train Loss", color='blue')
loss_ax.axhline(test_history['avg_test_loss'], label="Test Loss", color='red', linestyle='--')
loss_ax.set_title('Train Loss vs. Test Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.legend()

# Validation accuracy per epoch vs. final test accuracy
acc_ax.plot(train_history['val_acc'], label='Val Acc')
acc_ax.axhline(test_history['test_acc'], label='Test Acc', color='red', linestyle='--')
acc_ax.set_title('Validation Accuracy vs. Test Accuracy')
acc_ax.set_xlabel('Epoch')
acc_ax.set_ylabel('Accuracy')
acc_ax.legend()

fig.tight_layout()
plt.show()