In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import pickle

In [None]:
# Swin Transformer block
# NOTE(review): despite the name, this has none of the Swin-specific machinery (no
# window partitioning, shifted windows, patch merging, relative position bias or
# residual connections): it is one nn.MultiheadAttention call followed by a two-layer
# Linear -> LayerNorm -> ReLU stack.
class SwinTransformerBlock(nn.Module):
    """Multi-head attention followed by two Linear/LayerNorm/ReLU stages.

    Maps a (B, in_channels, H, W) tensor to (B, out_channels, S, S) with
    S = int(sqrt(H * W)).

    Args:
        in_channels: embedding size of the attention; must equal the input's
            channel dimension and be divisible by ``num_heads``.
        out_channels: number of output channels.
        num_heads: number of attention heads.

    NOTE(review): the model loaded from "Fusion Model.pkl" was presumably trained
    through this exact forward pass, so the layout issues flagged in ``forward``
    cannot be corrected without retraining. Changing ``__init__`` would not affect
    an unpickled instance anyway, because unpickling does not call ``__init__``.
    Confirm before touching the logic.
    """

    def __init__(self, in_channels: int, out_channels: int, num_heads: int = 1):
        super(SwinTransformerBlock, self).__init__()
        # batch_first is left at its default (False), so this layer reads its
        # inputs as (seq_len, batch, embed_dim).
        self.attention = nn.MultiheadAttention(embed_dim=in_channels, num_heads=num_heads)
        self.linear1 = nn.Linear(in_channels, out_channels)
        self.linear2 = nn.Linear(out_channels, out_channels)
        self.norm1 = nn.LayerNorm(out_channels)
        self.norm2 = nn.LayerNorm(out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply attention + MLP and reshape back to a square feature map.

        Assumes ``x`` is a contiguous (B, C, H, W) tensor with C == in_channels and
        H * W a perfect square (H == W for the spatial layout to make sense) - TODO
        confirm for any caller other than FusionModel. Note that ``view`` raises on
        non-contiguous input.

        Returns:
            Tensor of shape (B, out_channels, S, S), S = int(sqrt(H * W)).
        """
        # NOTE(review): ``view`` only reinterprets memory. On a (B, C, H, W) tensor it
        # does NOT move channels to the last axis, so each length-C "token" holds C
        # consecutive values from the same channel plane (neighbouring pixels), not one
        # pixel's C channels. The layout-correct form is ``x.flatten(2).transpose(1, 2)``.
        x = x.view(x.size(0), -1, x.size(1))
        # NOTE(review): because batch_first=False, this (B, H*W, C) tensor is read as
        # (seq_len=B, batch=H*W, embed=C). Attention therefore mixes samples across the
        # batch rather than across spatial positions. With a single-image batch (as in
        # ``predict``) each token can only attend to itself.
        attn_output, _ = self.attention(x, x, x)
        # Two Linear -> LayerNorm -> ReLU stages; the first projects to out_channels.
        x = self.linear1(attn_output)
        x = self.norm1(x)
        x = nn.functional.relu(x)
        x = self.linear2(x)
        x = self.norm2(x)
        x = nn.functional.relu(x)
        # NOTE(review): viewing (B, H*W, out_channels) as (B, out_channels, S, S) without a
        # transpose again scrambles channels and positions. The layout-correct form is
        # ``x.transpose(1, 2).reshape(B, -1, S, S)``. int(sqrt(...)) fails in ``view`` if
        # H*W is not a perfect square.
        x = x.view(x.size(0), x.size(2), int(x.size(1) ** 0.5), int(x.size(1) ** 0.5))
        return x

# ConvNeXt block
class ConvNeXtBlock(nn.Module):
    """3x3 convolution followed by per-pixel LayerNorm over channels and a ReLU.

    NOTE(review): despite the name, this is not a ConvNeXt block. It has no depthwise
    7x7 conv, no inverted-bottleneck MLP and no residual connection. ``self.linear``
    is created but never used in ``forward``. Removing it would change the module's
    state_dict keys and parameter order, so leave it unless checkpoints are regenerated.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the convolution; LayerNorm normalizes over them.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm = nn.LayerNorm(out_channels)
        self.linear = nn.Linear(out_channels, out_channels)

    def forward(self, x):
        """Map (B, in_channels, H, W) to (B, out_channels, H, W); padding=1 keeps H and W."""
        features = self.conv(x)
        # LayerNorm normalizes the trailing dimension, so move channels last,
        # normalize each pixel's channel vector, then move channels back.
        channels_last = features.permute(0, 2, 3, 1)
        normalized = self.norm(channels_last).permute(0, 3, 1, 2)
        return nn.functional.relu(normalized)

# FusionModel
class FusionModel(nn.Module):
    """Two-branch image classifier fusing a SwinTransformerBlock and a ConvNeXtBlock.

    Both branches turn the 3-channel input into 64-channel feature maps of the same
    (square) spatial size. These are concatenated to 128 channels, projected to a
    single channel by a 1x1 convolution, flattened, and classified by a linear layer.

    Args:
        num_classes: number of output logits (default 12, the original hard-coded value).
        image_size: side length S of the square input images (default 224). The
            classifier consumes one feature per pixel, so inputs must be (B, 3, S, S).

    Notes:
        ``spatial_attention_mechanism`` is a plain 1x1 projection; its output is
        not used to reweight the branch features. Attribute names are kept as-is
        because they are the keys of the saved parameters.
        An instance restored with ``pickle.load`` does not run ``__init__``, so
        these arguments only affect freshly constructed models.
    """

    def __init__(self, num_classes=12, image_size=224):
        super(FusionModel, self).__init__()
        self.swin_transformer_block1 = SwinTransformerBlock(in_channels=3, out_channels=64)
        self.convnext_block1 = ConvNeXtBlock(in_channels=3, out_channels=64)
        # 64 (swin branch) + 64 (convnext branch) channels in, one channel out.
        self.spatial_attention_mechanism = nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1)
        # One input feature per pixel of the single-channel S x S map.
        self.classifier = nn.Linear(image_size * image_size, num_classes)

    def forward(self, x):
        """Return logits of shape (B, num_classes) for an input of shape (B, 3, S, S)."""
        swin_output = self.swin_transformer_block1(x)
        convnext_output = self.convnext_block1(x)
        combined_output = torch.cat((swin_output, convnext_output), dim=1)
        attention_output = self.spatial_attention_mechanism(combined_output)
        # (B, 1, S, S) -> (B, S*S) for the linear classifier.
        attention_output = attention_output.flatten(start_dim=1)
        return self.classifier(attention_output)


In [None]:
# Load the trained model.
# SECURITY: pickle.load can execute arbitrary code embedded in the file - only load
# model files from a trusted source.
# NOTE(review): the pickle presumably references FusionModel and its sub-block classes,
# which must be defined (same module and names as when it was saved) before this cell
# runs. If the model was pickled while on a GPU, loading may fail on a CPU-only machine;
# a state_dict saved with torch.save would be more portable.
model_path = "Fusion Model.pkl"
with open(model_path, mode='rb') as model_file:
    model = pickle.load(model_file)

In [None]:
# Use the GPU when PyTorch can see one, otherwise the CPU, and move the loaded
# model's parameters and buffers onto that device (nn.Module.to works in place).
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

In [None]:
# Model Usage
def predict(model, image_path, transform, save_path='transformed_image.png'):
    """Classify a single image and return the predicted class index.

    Args:
        model: classifier that returns logits of shape (1, num_classes) for a
            batch containing one image. It is switched to eval mode.
        image_path: path of the image file; it is converted to RGB before
            ``transform`` is applied.
        transform: callable mapping a PIL image to a (C, H, W) tensor the model accepts.
        save_path: where to write a viewable copy of the transformed image, to
            inspect what the model actually sees. Defaults to
            'transformed_image.png' (the previous behaviour); pass None to skip saving.

    Returns:
        int: index of the highest-scoring class.
    """
    model.eval()
    # Run on the device the model's parameters live on rather than relying on a
    # global ``device`` defined in another notebook cell.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # Context manager closes the underlying file handle; convert() returns a loaded copy.
    with Image.open(image_path) as image:
        rgb_image = image.convert('RGB')
    transformed_image = transform(rgb_image)

    if save_path is not None:
        # A normalized tensor has values outside [0, 1]. ToPILImage scales floats by
        # 255 and casts to uint8, which corrupts out-of-range values, so min-max
        # rescale to [0, 1] first to get a viewable picture.
        preview = transformed_image.detach().float().cpu()
        lo, hi = preview.min(), preview.max()
        preview = (preview - lo) / (hi - lo).clamp_min(1e-12)
        transforms.ToPILImage()(preview).save(save_path)

    batch = transformed_image.unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)
    return logits.argmax(dim=1).item()

In [None]:
# Preprocessing applied to every input image:
#   1. resize to 224x224 - FusionModel's classifier expects exactly 224*224 features;
#   2. convert the PIL image to a float tensor;
#   3. normalize per channel. These are the commonly used ImageNet mean/std values;
#      presumably they match the training-time preprocessing - confirm.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
# Predictions: classify one unseen image and report the winning class index.
image_path = 'unseen_5.jpeg'
prediction = predict(model=model, image_path=image_path, transform=transform)
print(f'Predicted class: {prediction}')