# Rank Features

## Imports

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import models
from sklearn.model_selection import train_test_split
from sklearn.metrics import mutual_info_score
from skimage.feature import hog
from skimage import exposure
import cv2


## Utility functions

In [3]:
def create_model():
    """Build an ImageNet-pretrained ResNet-18 with a trainable one-logit head.

    The original classifier is replaced by ``nn.Linear(in_features, 1)``,
    a single logit suitable for ``BCEWithLogitsLoss``. Every backbone
    parameter is frozen (``requires_grad=False``); only the new head's
    parameters remain trainable.

    Returns:
        The torchvision ResNet-18 ``nn.Module`` with the replaced head.
    """
    # torchvision >= 0.13 deprecates ``pretrained=True`` in favour of the
    # ``weights=`` enum (DEFAULT is the same IMAGENET1K_V1 checkpoint for
    # ResNet-18). Fall back to the legacy keyword on older releases.
    weights_enum = getattr(models, "ResNet18_Weights", None)
    if weights_enum is not None:
        model = models.resnet18(weights=weights_enum.DEFAULT)
    else:
        model = models.resnet18(pretrained=True)

    # Freeze the whole pretrained network, then swap in a fresh head.
    model.requires_grad_(False)
    model.fc = nn.Linear(model.fc.in_features, 1)
    # New parameters already default to requires_grad=True; set it
    # explicitly so the intent (train only the head) is unmistakable.
    model.fc.requires_grad_(True)

    return model

def cnn_feature_importance(train_loader, test_loader, device, epochs=10, lr=0.001):
    """Fit a linear head on a frozen ResNet-18, then return per-channel mean ``layer4`` activations.

    Args:
        train_loader: iterable of ``(inputs, labels)`` batches used to fit the
            one-logit head with ``BCEWithLogitsLoss``. Labels are cast to
            float and flattened to match the logits.
        test_loader: iterable of ``(inputs, _)`` batches whose ``layer4``
            activations are averaged.
        device: device passed to ``.to()`` for the model and inputs.
        epochs: number of passes over ``train_loader`` (default 10).
        lr: Adam learning rate for the head (default 0.001).

    Returns:
        1-D float32 numpy array with one entry per ``layer4`` channel. Each
        entry is that channel's activation averaged over every test image
        and spatial position.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model = create_model().to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.fc.parameters(), lr=lr)

    # The backbone is frozen, so keep the whole network in eval mode while
    # fitting the head. model.train() would still update the BatchNorm
    # running statistics and silently drift the "frozen" features. The head
    # is a plain nn.Linear (same in both modes), and eval mode does not
    # disable autograd.
    model.eval()
    for _ in range(epochs):
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            targets = labels.float().reshape(-1).to(device)
            optimizer.zero_grad()
            # reshape(-1) instead of squeeze(): squeeze() turns a batch of one
            # into a 0-d tensor, which BCEWithLogitsLoss rejects against a
            # shape-(1,) target.
            logits = model(inputs).reshape(-1)
            loss = criterion(logits, targets)
            loss.backward()
            optimizer.step()

    # NOTE(review): because the backbone is frozen, the activations averaged
    # below do not depend on the head trained above (model.fc is discarded).
    # Confirm this is the intended importance measure.
    backbone = nn.Sequential(
        model.conv1, model.bn1, model.relu, model.maxpool,
        model.layer1, model.layer2, model.layer3, model.layer4,
    )

    # Running float64 sum instead of stacking every feature map in memory.
    channel_sum = None
    n_positions = 0
    with torch.no_grad():
        for inputs, _ in test_loader:
            fmap = backbone(inputs.to(device))  # (B, C, H', W')
            batch_sum = fmap.double().sum(dim=(0, 2, 3))
            channel_sum = batch_sum if channel_sum is None else channel_sum + batch_sum
            n_positions += fmap.shape[0] * fmap.shape[2] * fmap.shape[3]

    if channel_sum is None:
        raise ValueError("test_loader yielded no batches")

    return (channel_sum / n_positions).cpu().numpy().astype(np.float32)

def mutual_information_ranking(train_loader, n_bins=None):
    """Score every flattened input element by its mutual information with the label.

    All batches are materialised, each sample is flattened, and
    ``sklearn.metrics.mutual_info_score`` is computed between every flattened
    element (pixel/channel value) and the label vector.

    ``mutual_info_score`` treats both arguments as *discrete* labels: each
    distinct value is its own category. With continuous pixel values (for
    example after interpolating augmentations) this inflates the scores.
    Pass ``n_bins`` to quantise the inputs first.

    Args:
        train_loader: iterable of ``(inputs, labels)`` torch tensor batches.
        n_bins: optional number of equal-width bins spanning the global
            min..max of the inputs, used to discretise values before scoring.
            ``None`` (default) scores the raw values unchanged.

    Returns:
        1-D numpy array with one MI score per flattened input element, in
        row-major order of ``inputs.shape[1:]``.

    Raises:
        ValueError: if ``train_loader`` yields no batches or ``n_bins < 1``.
    """
    image_batches, label_batches = [], []
    for inputs, labels in train_loader:
        image_batches.append(inputs.numpy())
        label_batches.append(labels.numpy())
    if not image_batches:
        raise ValueError("train_loader yielded no batches")

    X = np.concatenate(image_batches, axis=0)
    y = np.concatenate(label_batches, axis=0).ravel()
    X_flat = X.reshape(X.shape[0], -1)

    if n_bins is not None:
        if n_bins < 1:
            raise ValueError("n_bins must be a positive integer")
        # Interior edges only; searchsorted maps values to bin ids 0..n_bins-1.
        inner_edges = np.linspace(X_flat.min(), X_flat.max(), n_bins + 1)[1:-1]
        X_flat = np.searchsorted(inner_edges, X_flat, side="right")

    # NOTE: one sklearn call per element - slow for full-resolution images.
    return np.array([mutual_info_score(column, y) for column in X_flat.T])

def hog_feature_importance(train_loader, orientations=8, pixels_per_cell=(16, 16),
                           cells_per_block=(1, 1)):
    """Average the HOG visualisation image over a dataset and flatten it.

    Each image is converted to 8-bit grayscale and passed through
    ``skimage.feature.hog(..., visualize=True)``. The returned HOG
    *visualisation* image (not the descriptor vector) is averaged pixel-wise
    across all images.

    Args:
        train_loader: iterable of ``(inputs, _)`` batches. Each ``inputs`` is
            presumably a float tensor of shape (B, C, H, W) with C == 3 (RGB)
            or C == 1 and values in [0, 1] - TODO confirm against the loader's
            transforms. Out-of-range values (e.g. ImageNet-normalised inputs)
            are clipped to [0, 1]; de-normalise first if that is not wanted.
        orientations, pixels_per_cell, cells_per_block: forwarded to
            ``skimage.feature.hog``; defaults match the previous hard-coded
            values.

    Returns:
        1-D float array with one value per pixel (all images must share H x W).

    Raises:
        ValueError: if ``train_loader`` yields no images.
    """
    hog_sum = None
    n_images = 0
    for inputs, _ in train_loader:
        for tensor_img in inputs:
            hwc = tensor_img.permute(1, 2, 0).numpy()
            # Clip before the uint8 cast: negative or >1 values would
            # otherwise wrap around. This is a no-op for data already in
            # [0, 1].
            img_u8 = np.ascontiguousarray(
                (np.clip(hwc, 0.0, 1.0) * 255).astype(np.uint8))
            if img_u8.shape[2] == 1:
                gray = img_u8[:, :, 0]
            else:
                gray = cv2.cvtColor(img_u8, cv2.COLOR_RGB2GRAY)
            _, hog_image = hog(gray, orientations=orientations,
                               pixels_per_cell=pixels_per_cell,
                               cells_per_block=cells_per_block, visualize=True)
            # Running sum instead of keeping every HOG image in memory.
            if hog_sum is None:
                hog_sum = np.array(hog_image, dtype=np.float64)
            else:
                hog_sum += hog_image
            n_images += 1

    if n_images == 0:
        raise ValueError("train_loader yielded no images")
    return (hog_sum / n_images).flatten()

def plot_feature_importance(importance, title):
    """Draw a bar chart of importance scores against their feature index.

    Args:
        importance: sized sequence of scores; bar ``i`` has height
            ``importance[i]``.
        title: text shown as the chart title.

    A new 10x6-inch figure is created and displayed via ``plt.show()``.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.bar(range(len(importance)), importance)
    ax.set(title=title, xlabel='Feature Index', ylabel='Importance')
    fig.tight_layout()
    plt.show()