In [1]:
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import os

model = models.resnet50(pretrained=True)
model = torch.nn.Sequential(*(list(model.children())[:-1])) 

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def extract_features(image_path, model, transform):
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0)
    with torch.no_grad():
        features = model(image)
    return features.squeeze()

all_features = []
all_labels = []
image_paths = os.listdir('./graphs')

class_to_idx = {label: idx for idx, label in enumerate(set(image_paths))}
labels = [class_to_idx[label] for label in image_paths]

for image_path in image_paths:
    features = extract_features('./graphs/' + image_path, model, transform)
    all_features.append(features)

all_features = np.array(all_features)


  from .autonotebook import tqdm as notebook_tqdm
  all_features = np.array(all_features)
  all_features = np.array(all_features)


In [2]:
all_features = np.array([tensor.numpy() for tensor in all_features])

In [3]:
import torch.nn as nn
import torch

class CrossAttention(nn.Module):
    def __init__(self, feature_dim, num_heads):
        super(CrossAttention, self).__init__()
        self.query = nn.Linear(feature_dim, feature_dim)
        self.key = nn.Linear(feature_dim, feature_dim)
        self.value = nn.Linear(feature_dim, feature_dim)
        self.num_heads = num_heads
        self.attention = nn.MultiheadAttention(embed_dim=feature_dim, num_heads=num_heads)

    def forward(self, features):
        query = self.query(features)
        key = self.key(features)
        value = self.value(features)
        
        query = query.unsqueeze(1).transpose(0, 1)
        key = key.unsqueeze(1).transpose(0, 1)
        value = value.unsqueeze(1).transpose(0, 1)
        
        attended_features, _ = self.attention(query, key, value)
        return attended_features.squeeze(0)

In [4]:
class ImageClassifier(nn.Module):
    def __init__(self, feature_dim, num_classes, num_heads):
        super(ImageClassifier, self).__init__()
        self.cross_attention = CrossAttention(feature_dim, num_heads)
        self.fc = nn.Linear(feature_dim, num_classes)

    def forward(self, features):
        attended_features = self.cross_attention(features)
        logits = self.fc(attended_features)
        return logits

feature_dim = all_features.shape[1]
num_classes = len(class_to_idx)
num_heads = 8
classifier = ImageClassifier(feature_dim, num_classes, num_heads)

features_tensor = torch.tensor(all_features, dtype=torch.float32)
labels_tensor = torch.tensor(labels, dtype=torch.long)

In [5]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(classifier.parameters(), lr=0.001)

num_epochs = 500
for epoch in range(num_epochs):
    classifier.train()
    optimizer.zero_grad()
    outputs = classifier(features_tensor)
    loss = criterion(outputs, labels_tensor)
    loss.backward()
    optimizer.step()

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

Epoch [1/500], Loss: 2.7739
Epoch [2/500], Loss: 5.8773
Epoch [3/500], Loss: 7.3979
Epoch [4/500], Loss: 4.4566
Epoch [5/500], Loss: 30.2278
Epoch [6/500], Loss: 27.8231
Epoch [7/500], Loss: 21.1104
Epoch [8/500], Loss: 28.6380
Epoch [9/500], Loss: 37.0786
Epoch [10/500], Loss: 24.6308
Epoch [11/500], Loss: 25.1410
Epoch [12/500], Loss: 15.4224
Epoch [13/500], Loss: 17.6662
Epoch [14/500], Loss: 14.6495
Epoch [15/500], Loss: 10.1206
Epoch [16/500], Loss: 18.2036
Epoch [17/500], Loss: 13.0870
Epoch [18/500], Loss: 8.9252
Epoch [19/500], Loss: 5.1947
Epoch [20/500], Loss: 5.4625
Epoch [21/500], Loss: 7.3491
Epoch [22/500], Loss: 11.1725
Epoch [23/500], Loss: 6.8633
Epoch [24/500], Loss: 6.5432
Epoch [25/500], Loss: 7.9515
Epoch [26/500], Loss: 8.0273
Epoch [27/500], Loss: 6.7042
Epoch [28/500], Loss: 5.1246
Epoch [29/500], Loss: 5.1584
Epoch [30/500], Loss: 4.4987
Epoch [31/500], Loss: 7.0125
Epoch [32/500], Loss: 6.2976
Epoch [33/500], Loss: 4.4686
Epoch [34/500], Loss: 4.6787
Epoch [35

In [6]:
def evaluate(classifier, features, labels):
    classifier.eval()
    with torch.no_grad():
        outputs = classifier(features)
        _, predicted = torch.max(outputs.data, 1)
        accuracy = (predicted == labels).sum().item() / len(labels)
    return accuracy

accuracy = evaluate(classifier, features_tensor, labels_tensor)
print(f'Accuracy: {accuracy * 100:.2f}%')

Accuracy: 93.75%
