In [None]:
class CreatePatches(nn.Module):
    """Cut an image into non-overlapping patches and embed each patch.

    A convolution whose kernel size equals its stride turns every
    ``patch_size`` x ``patch_size`` window into one ``embed_dim`` vector.
    """

    def __init__(self, channels, patch_size, embed_dim):
        super().__init__()
        # kernel_size == stride, so the windows tile the input without overlap.
        self.patch = nn.Conv2d(channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Embed ``x`` and return the result with the embedding axis last.

        The convolution produces (batch, embed_dim, rows, cols); the returned
        tensor is the same data viewed as (batch, rows, cols, embed_dim).
        """
        embedded = self.patch(x)
        return embedded.permute(0, 2, 3, 1)

In [None]:
class LayerNormalization(nn.Module):
    """Layer normalisation over the trailing (embedding) axis of a tensor.

    Args:
        embed_dim: Size of the last axis of the input; this is the axis that
            is normalised.
        epsilon: Passed to ``nn.LayerNorm`` as ``eps``.
    """

    def __init__(self, embed_dim, epsilon=1e-5):
        super().__init__()

        self.epsilon = epsilon
        self.pre_norm = nn.LayerNorm(embed_dim, eps=epsilon)

    def forward(self, image):
        """Normalise ``image`` along its last axis and return the same shape.

        ``nn.LayerNorm`` already works on the trailing axis of an input of any
        rank, so no flattening is needed. The previous implementation
        flattened to a hard-coded width of 768, unpacked the shape as
        (b, c, h, w) although the tensor is channels-last, and then reshaped
        the result into a different axis order, which reinterpreted the memory
        and scrambled the normalised values.

        Args:
            image: Tensor whose last axis has size ``embed_dim``, e.g.
                (batch, rows, cols, embed_dim).

        Returns:
            Tensor with the shape of ``image``.
        """
        return self.pre_norm(image)

In [None]:
class ChannelAttention(nn.Module):
    """Per-channel gate built from globally pooled channel descriptors.

    The input is reduced to one value per channel twice (average pooling and
    max pooling). Both descriptors go through the same two-layer bottleneck,
    are summed, and are squashed with a sigmoid.

    Args:
        in_planes: Number of channels of the input.
        ratio: Divisor applied to ``in_planes`` to get the bottleneck width.
    """

    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        bottleneck = in_planes // ratio
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # One shared network processes both pooled descriptors.
        self.fc = nn.Sequential(
            nn.Linear(in_planes, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, in_planes),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the gate with shape (batch, channels, 1, 1)."""
        batch = x.size(0)
        pooled_avg = self.avg_pool(x).view(batch, -1)
        pooled_max = self.max_pool(x).view(batch, -1)
        gate = self.sigmoid(self.fc(pooled_avg) + self.fc(pooled_max))
        # Trailing singleton axes let the gate broadcast over the spatial axes.
        return gate.view(batch, -1, 1, 1)

class SpatialAttention(nn.Module):
    """Per-location gate built from channel-wise mean and max maps.

    Args:
        kernel_size: Kernel size of the convolution that mixes the two maps.
            The padding is ``kernel_size // 2``.
    """

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        half = kernel_size // 2
        # Two input planes (mean map, max map) -> one attention plane.
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=half)

    def forward(self, x):
        """Return a single-channel gate with values in (0, 1)."""
        mean_map = torch.mean(x, dim=1, keepdim=True)
        max_map = torch.max(x, dim=1, keepdim=True)[0]
        stacked = torch.cat((mean_map, max_map), dim=1)
        return torch.sigmoid(self.conv(stacked))

class CBAM(nn.Module):
    """Channel attention followed by spatial attention.

    Args:
        in_planes: Number of channels of the input feature map.
        ratio: Forwarded to ``ChannelAttention``. The default matches the
            value that was previously fixed.
        kernel_size: Forwarded to ``SpatialAttention``. The default matches
            the value that was previously fixed.
    """

    def __init__(self, in_planes, ratio=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.ca = ChannelAttention(in_planes, ratio)
        self.sa = SpatialAttention(kernel_size)

    def forward(self, x):
        """Gate ``x`` by channel first, then gate the result by location.

        The spatial gate is computed from the channel-gated tensor, not from
        the raw input.
        """
        out = x * self.ca(x)
        out = out * self.sa(out)
        return out

In [None]:
class Transformers1(nn.Module):
    """Encoder block: CBAM attention and an MLP, each inside a residual connection.

    Args:
        embed_dim: Size of the last axis of the input.
        hidden_dims: Width of the hidden layer of the MLP.
        epsilon: ``eps`` of the LayerNorm applied before the MLP.
        ratio: Accepted for interface compatibility; not used by this block.
        kernel_size: Accepted for interface compatibility; not used by this
            block.
        dropout: Dropout probability used inside the MLP.
    """

    def __init__(self, embed_dim, hidden_dims, epsilon, ratio, kernel_size=2, dropout=0.0):
        super().__init__()

        self.pre_norm = LayerNormalization(embed_dim, epsilon=1e-6)
        self.attention = CBAM(embed_dim)
        self.norm = nn.LayerNorm(embed_dim, epsilon)
        self.MLP = nn.Sequential(
            nn.Linear(embed_dim, hidden_dims),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dims, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the block to a channels-last tensor (batch, rows, cols, embed_dim).

        Returns a tensor of the same shape.
        """
        # Normalise before attention. The previous code computed the
        # normalised tensor and then overwrote it with a rearranged copy of
        # the raw input, so attention never saw the normalised values.
        # The wrapped nn.LayerNorm is applied directly: it normalises the
        # trailing axis of a tensor of any rank, so no reshaping is involved.
        x_norm = self.pre_norm.pre_norm(x)

        # CBAM expects channels-first; switch layout for it and back again.
        x_att = self.attention(x_norm.permute(0, 3, 1, 2))
        x = x + x_att.permute(0, 2, 3, 1)

        # LayerNorm and Linear both act on the trailing axis, so the MLP can
        # run on the 4-D tensor without flattening to a hard-coded width.
        return x + self.MLP(self.norm(x))

In [None]:
class LSTMModel(nn.Module):
    """Stacked batch-first LSTM that returns the output of every time step.

    Args:
        input_size: Feature size of each element of the input sequence.
        hidden_size: Feature size of the LSTM hidden state and of the output.
        n_layers: Number of stacked LSTM layers.
    """

    def __init__(self, input_size, hidden_size, n_layers):
        super(LSTMModel, self).__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.lstm = nn.LSTM(input_size, hidden_size, n_layers, batch_first=True)

    def forward(self, x):
        """Run the LSTM from a zero initial state.

        ``nn.LSTM`` uses zero hidden and cell states when none are passed,
        created with the device and dtype of the input. Relying on that
        avoids allocating float32 zeros on the CPU and copying them over on
        every call, and keeps the module usable after ``.double()`` or
        ``.half()``.

        Args:
            x: Tensor of shape (batch, sequence, input_size).

        Returns:
            Tensor of shape (batch, sequence, hidden_size); the final hidden
            and cell states are discarded.
        """
        out, _ = self.lstm(x)
        return out

In [None]:
class ViTLSTM(nn.Module):
    """Patch embedding -> attention blocks -> LSTM -> fully connected classifier.

    Args:
        d_model: Second argument passed to ``PositionalEncoding``.
        max_len: First argument passed to ``PositionalEncoding``.
        epsilon: ``epsilon`` given to every ``Transformers1`` block.
        img_size: Side length used, together with ``patch_size``, to compute
            the number of patches as ``(img_size // patch_size) ** 2``.
        in_channels: Number of channels of the input images.
        patch_size: Side length of each patch.
        embed_dim: Size of each patch embedding.
        hidden_dims: ``hidden_dims`` given to every ``Transformers1`` block.
        num_layers: Number of ``Transformers1`` blocks.
        dropout: Dropout probability after the positional encoding and inside
            every ``Transformers1`` block.
        num_classes: Number of outputs of the final linear layer.
        ratio: ``ratio`` given to every ``Transformers1`` block.
        kernel_size: ``kernel_size`` given to every ``Transformers1`` block.
        hidden_size: Hidden size of the LSTM.
        input_size: Input feature size of the LSTM. The LSTM consumes the
            embedded patch sequence, so this has to equal ``embed_dim``.
        n_layers: Number of stacked LSTM layers.
    """

    def __init__(
        self, d_model, max_len, epsilon=1e-05,
        img_size=192,
        in_channels=3,
        patch_size=16,
        embed_dim=768,
        hidden_dims=3072,
        num_layers=12,
        dropout=0.1,
        num_classes=2,
        ratio=16,
        kernel_size=7, hidden_size=1024, input_size=768, n_layers=2
    ):
        super().__init__()
        self.patch_size = patch_size
        num_patches = (img_size // patch_size) ** 2
        self.patches = CreatePatches(
            channels=in_channels,
            embed_dim=embed_dim,
            patch_size=patch_size
        )

        self.position_embedding = PositionalEncoding(max_len, d_model)
        # Keyword arguments are required here: the positional call used
        # before did not follow the parameter order of Transformers1, so
        # `ratio` was used as the LayerNorm eps and `epsilon` as the dropout
        # probability.
        self.attn_layers = nn.ModuleList([
            Transformers1(
                embed_dim=embed_dim,
                hidden_dims=hidden_dims,
                epsilon=epsilon,
                ratio=ratio,
                kernel_size=kernel_size,
                dropout=dropout,
            )
            for _ in range(num_layers)
        ])
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(embed_dim, eps=1e-06)
        self.lstm1 = LSTMModel(input_size, hidden_size, n_layers)

        self.flat = nn.Flatten(1)
        # The flattened LSTM output holds one hidden_size vector per patch
        # (144 * 1024 = 147456 with the default arguments).
        self.fc1 = nn.Linear(num_patches * hidden_size, 512)
        self.relu1 = nn.LeakyReLU(0.1)
        self.bn1 = nn.BatchNorm1d(512)
        self.fc2 = nn.Linear(512, 128)
        self.relu2 = nn.LeakyReLU(0.1)
        self.drop = nn.Dropout(0.5)
        self.fc3 = nn.Linear(128, 16)
        self.relu3 = nn.LeakyReLU(0.1)
        self.fc5 = nn.Linear(16, num_classes)
        # NOTE(review): despite its name this is a Sigmoid, so the model
        # output lies in (0, 1). Losses that expect unnormalised logits
        # (e.g. nn.CrossEntropyLoss) should be checked against this.
        self.relu5 = nn.Sigmoid()

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Image batch accepted by ``CreatePatches``.

        Returns:
            Tensor of shape (batch, num_classes) holding sigmoid outputs.
        """
        x = self.patches(x)
        b, h, w, c = x.shape
        # Recorded on the instance as before; nothing in this class reads it.
        self.max_len = h * w

        # The positional encoding works on a patch sequence, the attention
        # blocks on a patch grid, so the tensor is reshaped around each.
        x = x.reshape(b, h * w, c)
        x = self.position_embedding(x)
        x = self.dropout(x)
        x = x.reshape(b, h, w, c)
        for layer in self.attn_layers:
            x = layer(x)
        x = self.ln(x)
        x = x.reshape(b, h * w, c)

        x = self.lstm1(x)
        x = self.flat(x)
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.bn1(x)
        x = self.fc2(x)
        x = self.relu2(x)
        x = self.drop(x)
        x = self.fc3(x)
        x = self.relu3(x)
        x = self.fc5(x)
        x = self.relu5(x)
        return x

In [None]:
import math
from torch.optim.rmsprop import RMSprop
from torch.optim import Adam, SGD,RMSprop
# Build the network and move it to the target device.
# max_len=144 is the patch count for 192x192 images cut into 16x16 patches
# ((192 // 16) ** 2).
model = ViTLSTM(
    d_model=768,
    max_len=144,
    epsilon=1e-05,
    img_size=192,
    in_channels=3,
    patch_size=16,
    embed_dim=768,
    hidden_dims=3072,
    num_layers=12,
    dropout=0.1,
    num_classes=2,
    ratio=16,
    kernel_size=7,
    hidden_size=1024,
    input_size=768,
    n_layers=2,
)
model = model.to(device)
print(model)

# Loss and optimiser used by the training cell below.
criterion = nn.CrossEntropyLoss().to(device)
optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)

ViTLSTM(
  (patches): CreatePatches(
    (patch): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  )
  (position_embedding): PositionalEncoding()
  (attn_layers): ModuleList(
    (0-11): 12 x Transformers1(
      (pre_norm): LayerNormalization(
        (pre_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      )
      (attention): CBAM(
        (ca): ChannelAttention(
          (avg_pool): AdaptiveAvgPool2d(output_size=1)
          (max_pool): AdaptiveMaxPool2d(output_size=1)
          (fc): Sequential(
            (0): Linear(in_features=768, out_features=48, bias=True)
            (1): ReLU()
            (2): Linear(in_features=48, out_features=768, bias=True)
          )
          (sigmoid): Sigmoid()
        )
        (sa): SpatialAttention(
          (conv): Conv2d(2, 1, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
        )
      )
      (norm): LayerNorm((768,), eps=16, elementwise_affine=True)
      (MLP): Sequential(
        (0): Linear(in_features=76

In [None]:
# Training loop
def _count_correct(outputs, labels):
    """Return how many rows of ``outputs`` have their arg-max equal to ``labels``."""
    _, predicted = torch.max(outputs, 1)
    return (predicted == labels).sum().item()


def _evaluate(model, loader, criterion=None):
    """Run ``model`` over ``loader`` in eval mode without gradients.

    Returns ``(loss_sum, correct, total)``: the sum of the per-batch losses
    (0.0 when ``criterion`` is None), the number of correct predictions and
    the number of samples seen.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            if criterion is not None:
                loss_sum += criterion(outputs, labels).item()
            correct += _count_correct(outputs, labels)
            total += labels.size(0)
    return loss_sum, correct, total


nb_epoch = 25
for epoch in range(nb_epoch):
    model.train()
    train_running_loss = 0.0
    correct_train = 0
    total_train = 0

    for images, labels in train_dl:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_running_loss += loss.item()
        correct_train += _count_correct(outputs, labels)
        total_train += labels.size(0)

    train_accuracy = correct_train / total_train
    # Average of the per-batch losses for the epoch
    train_average_loss = train_running_loss / len(train_dl)

    # Validation pass (switches the model to eval mode)
    val_running_loss, correct_val, total_val = _evaluate(model, val_dl, criterion)
    val_accuracy = correct_val / total_val
    val_average_loss = val_running_loss / len(val_dl)

    print(f"Epoch {epoch+1} -- Train Loss: {train_average_loss:.2f} -- Train Accuracy: {train_accuracy:.2f} -- Validation Loss: {val_average_loss:.2f} -- Validation Accuracy: {val_accuracy:.2f} ")

# Testing loop
_, correct, total = _evaluate(model, test_dl)

# Calculate test accuracy
accuracy = correct / total

print(f"Test Accuracy: {accuracy:.2f}")

In [None]:
from sklearn.metrics import precision_score
# Re-run the test set and print the labels next to the predictions for every
# batch, followed by the overall accuracy in percent.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_dl:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        # Index 1 of torch.max's result holds the arg-max positions.
        predicted = torch.max(outputs.data, 1)[1]
        total += labels.size(0)
        # The float conversion makes `correct` accumulate as a float.
        correct += (predicted == labels).float().sum().item()

        print(f"Actual Value : {labels}  Predicted Value : {predicted}")

print("\n")
print("Accuracy of the network on the {} test images: {} %".format(total, 100 * correct / total))

Actual Value : tensor([0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1,
        0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
        1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1,
        0, 0, 0], device='cuda:0')  Predicted Value : tensor([0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1,
        0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
        1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1,
        0, 0, 0], device='cuda:0')
Actual Value : tensor([1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0,
        1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
        0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0,
        0, 0, 1], device='cuda:0')  Predicted Value : tensor([1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0,
        1, 1, 1, 0, 1, 0, 1