In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torchvision.models as models
from collections import OrderedDict
from torch.autograd import Variable
import pandas as pd

import time
import pickle
import pandas
import numpy as np 
import matplotlib.pyplot as plt
import scipy.stats as st
import gc
import random

import seaborn as sns
import matplotlib.pyplot as plt

import sys
import os

import tqdm
import cv2
from typing import Callable, List, Optional, Tuple
import torchvision.transforms as transforms
import ttach as tta


from torchvision import transforms



In [None]:
# Compute device shared by the notebook: the first CUDA GPU when CUDA is
# available, otherwise the CPU.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(DEVICE)

In [None]:
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
#A) main functions definition
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 

class ResNet_18(nn.Module):
    """torchvision ResNet-18 whose final ``fc`` layer is replaced by a small head.

    The head is ``Linear(512, 128) -> ReLU -> Linear(128, output_classes)`` and
    its sub-modules are registered as ``fc1``, ``relu`` and ``fc2``.

    Args:
        output_classes: number of scores produced per observation.
    """

    def __init__(self, output_classes=3):
        super().__init__()

        # Backbone requested with ``pretrained=True``.
        backbone = models.resnet18(pretrained=True)

        # Replacement head; the layers are created in the order fc1, fc2.
        head = OrderedDict()
        head['fc1'] = nn.Linear(512, 128)
        head['relu'] = nn.ReLU()
        head['fc2'] = nn.Linear(128, output_classes)
        backbone.fc = nn.Sequential(head)

        self.resnet_model = backbone

    def forward(self, x, apply_sigmoid=False):
        """Return the raw scores for ``x``, or their sigmoid when ``apply_sigmoid`` is true."""
        scores = self.resnet_model(x)
        return torch.sigmoid(scores) if apply_sigmoid else scores

def fit_model(model, X_data, y_data, EPOCHS=5, BATCH_SIZE=32):
    """Train ``model`` with Adam and ``BCEWithLogitsLoss``, then save its weights.

    Every epoch the observations are reshuffled and consumed in mini-batches of
    ``BATCH_SIZE``; an incomplete trailing batch is skipped. When the dataset
    holds fewer than ``BATCH_SIZE`` observations, a single batch containing all
    of them is used so that each epoch performs at least one update (the
    previous code ran no batch at all and then divided by zero).

    Args:
        model: module mapping a batch taken from ``X_data`` to one logit per
            entry of the matching ``y_data`` rows.
        X_data: 4-D tensor of inputs, indexed along its first dimension.
        y_data: 2-D tensor of targets with one row per observation.
        EPOCHS: number of passes over the data.
        BATCH_SIZE: requested mini-batch size (must be >= 1).

    Raises:
        ValueError: if ``X_data`` is empty or ``BATCH_SIZE`` is smaller than 1.

    Side effects:
        Prints one summary line per epoch and writes the final ``state_dict``
        to ``./modelFinal.pytorchModel``.
    """
    n = X_data.shape[0]
    if n == 0:
        raise ValueError("fit_model requires at least one training observation.")
    if BATCH_SIZE < 1:
        raise ValueError("BATCH_SIZE must be at least 1.")

    optimizer = torch.optim.Adam(model.parameters())
    error = nn.BCEWithLogitsLoss()

    # Move batches to wherever the model lives, so the function does not
    # depend on a notebook-level global being in sync with the model.
    device = next(model.parameters()).device

    # Never ask for more observations than exist.
    batch_size = min(BATCH_SIZE, n)

    model.train()

    for epoch in range(EPOCHS):
        obsIDs = np.arange(n)
        np.random.shuffle(obsIDs)

        epoch_loss = 0.0          # sum of the batch losses of this epoch
        correct_predictions = 0   # element-wise matches over all outputs
        total_samples = 0         # number of target entries seen

        # Only full batches: the last start index is n - batch_size.
        for batch_start in range(0, n - batch_size + 1, batch_size):
            Curr_obsIDs = obsIDs[batch_start:batch_start + batch_size]
            var_X_batch = X_data[Curr_obsIDs, :, :, :].float().to(device)
            var_y_batch = y_data[Curr_obsIDs, :].float().to(device)

            optimizer.zero_grad()
            output = model(var_X_batch)
            loss = error(output, var_y_batch)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

            # Training metrics only; no gradient tracking needed.
            with torch.no_grad():
                predictions = (torch.sigmoid(output) > 0.5).float()
                correct_predictions += (predictions == var_y_batch).sum().item()
            total_samples += var_y_batch.numel()

        epoch_accuracy = (correct_predictions / total_samples) * 100

        print(f'Epoch [{epoch+1}/{EPOCHS}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%')

    torch.save(model.state_dict(), "./modelFinal.pytorchModel")





def LargeDatasetPred(model, var_X, BlockSizes, DEVICE='cpu'):
    """
    Predict a large dataset block by block.

    ``var_X`` is split along its first dimension into consecutive blocks of at
    most ``BlockSizes`` observations. Each block is moved to ``DEVICE``, passed
    through ``model`` with gradient tracking disabled, and the result is moved
    back to the CPU. To be used when memory is too small to run the whole
    dataset through the network at once.

    Args:
        model: callable mapping a 4-D batch to a tensor of predictions.
        var_X: 4-D tensor of observations.
        BlockSizes: maximum number of observations per block (must be >= 1).
        DEVICE: device on which each block is evaluated.

    Returns:
        CPU tensor with the block predictions concatenated along dimension 0.
        For an empty ``var_X`` an empty tensor (``torch.empty(0)``) is returned.

    Raises:
        ValueError: if ``BlockSizes`` is smaller than 1 (the loop could never
            advance).
    """
    if BlockSizes < 1:
        raise ValueError("BlockSizes must be a positive integer.")

    n_loc = var_X.shape[0]

    # Collect the block predictions and concatenate once at the end instead
    # of re-copying the growing result after every block.
    block_preds = []
    with torch.no_grad():
        for loc_miniBatch_Start in range(0, n_loc, BlockSizes):
            loc_miniBatch_End = min(loc_miniBatch_Start + BlockSizes, n_loc)
            loc_X = var_X[loc_miniBatch_Start:loc_miniBatch_End, :, :, :].to(DEVICE)
            block_preds.append(model(loc_X).to('cpu'))

    if not block_preds:
        # No observation, hence no model output whose shape could be reused.
        return torch.empty(0)

    return torch.cat(block_preds, dim=0)
    

def showRGBImage(LodID,X,Y,S,M):
    """
    Display observation ``LodID`` of ``X`` in a new matplotlib figure.

    The image is ``X[LodID]`` multiplied by 255 and cast to int. The title
    shows the integer value of ``Y[LodID, 0]``, ``S[LodID, 0]`` and
    ``M[LodID, 0]``, prefixed with ``Y=``, ``E=`` and ``S=`` respectively.
    """
    pixels = X[LodID, :, :, :] * 255
    pixels = pixels.astype(int)

    caption = f"Y={int(Y[LodID, 0])}    E={int(S[LodID, 0])}     S={int(M[LodID, 0])}"

    plt.figure()
    plt.imshow(pixels)
    plt.title(caption)
    plt.show()

  





In [None]:
# Load the training attribute table and work on a float32 copy of it.
Info_all_train = pd.read_csv('./DATA_celebA/train.csv')
Y = Info_all_train.astype(np.float32)

print(Y.shape)

# Pairwise correlation between all attribute columns.
corr_matrix = Y.corr()

# Heatmap of the correlation matrix.
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(corr_matrix, cmap='coolwarm', linewidths=.05, ax=ax)
ax.set_title('Correlation Matrix')
plt.show()

In [None]:
# Keep every unordered pair of variables exactly once by masking everything
# except the strict upper triangle of the (symmetric) correlation matrix.
# This removes the diagonal (self-correlation) and the mirrored duplicates
# without comparing floating-point values, so two different pairs that happen
# to share the same |correlation| are both kept.
upper_triangle = np.triu(np.ones(corr_matrix.shape, dtype=bool), k=1)

# Flatten to one row per remaining pair; masked (NaN) entries are discarded.
corr_flat = corr_matrix.where(upper_triangle).stack().dropna().reset_index()
corr_flat.columns = ['Variable1', 'Variable2', 'Correlation']
corr_flat['abs_correlation'] = corr_flat['Correlation'].abs()

# Get the 100 most correlated pairs (by absolute correlation)
top_correlations = corr_flat.sort_values(by='abs_correlation', ascending=False).head(100)

print(top_correlations[['Variable1', 'Variable2', 'Correlation']])

In [None]:
# List the column names of the training attribute table.
print(list(Info_all_train.columns))

In [None]:
def _read_pickle(path):
    """Return the object stored in the pickle file at ``path``."""
    # NOTE(review): pickle.load can execute arbitrary code while loading;
    # only use it on files from a trusted source.
    with open(path, 'rb') as infile:
        return pickle.load(infile)


def _label_matrix(info, columns):
    """Return ``info[columns]`` as a float32 array with ``len(columns)`` columns."""
    return info[columns].values.astype(np.float32).reshape(-1, len(columns))


# Training set: images and attribute table.
X_train = _read_pickle('./DATA_celebA/train_64x64.pkl')
Info_all_train = pd.read_csv('./DATA_celebA/train.csv')

# Every column of the training table is used as a label
# (a subset such as ['Young', 'Eyeglasses', 'Smiling'] could be used instead).
labels = Info_all_train.columns.tolist()
l = len(labels)
Y_train = _label_matrix(Info_all_train, labels)

# Test set: same label columns as the training set.
X_test = _read_pickle('./DATA_celebA/test_64x64.pkl')
Info_all_test = pd.read_csv('./DATA_celebA/test.csv')
Y_test = _label_matrix(Info_all_test, labels)




In [None]:
# Draw 50 distinct random indices into the training set.
random_indices = random.sample(range(X_train.shape[0]), 50)

# Columns 0, 1 and 2 of Y_train as column vectors, the form showRGBImage indexes.
first_col, second_col, third_col = (Y_train[:, k].reshape(-1, 1) for k in range(3))

# For each sampled image: write it to a PNG in the working directory (no
# label in the file), print its index and display it with its first three labels.
for idx in random_indices:
    img = (X_train[idx] * 255).astype(np.uint8)  # scale to [0, 255] for saving
    img_filename = f'image_{idx}.png'
    cv2.imwrite(img_filename, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))  # OpenCV writes BGR
    print(idx)
    showRGBImage(idx, X_train, first_col, second_col, third_col)

In [None]:
# Convert the numpy arrays to float32 torch tensors stored on DEVICE.

def _images_to_tensor(images):
    """Return ``images`` as a float32 tensor on DEVICE with axes 1 and 3 swapped."""
    # NOTE(review): assuming ``images`` is laid out (N, H, W, C), transpose(1, 3)
    # yields (N, C, W, H): channels come first but height and width are also
    # exchanged (unlike permute(0, 3, 1, 2)). Anything fed to a model trained
    # on these tensors must be prepared the same way - confirm this is intended.
    return torch.from_numpy(images[:, :, :, :]).float().transpose(1, 3).to(DEVICE)


def _labels_to_tensor(targets):
    """Return ``targets`` as a float32 tensor on DEVICE."""
    return torch.from_numpy(targets).float().to(DEVICE)


torch_X_train = _images_to_tensor(X_train)
torch_y_train = _labels_to_tensor(Y_train)

torch_X_test = _images_to_tensor(X_test)
torch_y_test = _labels_to_tensor(Y_test)

In [None]:
# C) Training phase
# Hyper-parameters of this run.
EPOCHS_in, BATCH_SIZE_in = 50, 512

# One output per label column; the model is created directly on DEVICE.
model = ResNet_18(l).to(DEVICE)

# Trains in place and writes the final weights to ./modelFinal.pytorchModel.
fit_model(model, torch_X_train, torch_y_train, BATCH_SIZE=BATCH_SIZE_in, EPOCHS=EPOCHS_in)

In [None]:
# Rebuild the architecture and restore the weights written by fit_model.
model_test = ResNet_18(l)

# map_location remaps the stored tensors onto DEVICE, so a checkpoint saved
# from a GPU run can also be loaded on a machine without CUDA.
model_test.load_state_dict(torch.load("./modelFinal.pytorchModel", map_location=DEVICE))
model_test.to(DEVICE)


In [None]:
# Evaluate the reloaded model on the first N_EVAL test observations.
N_EVAL = 1024

model_test.eval()

# Work on slices of the torch tensors. X_test is deliberately NOT rebound to a
# tensor here: it holds the numpy images, and overwriting it would break a
# re-run of the numpy -> torch conversion cell.
eval_X = torch_X_test[:N_EVAL, :, :, :]
y_test = torch_y_test.cpu()
eval_y = y_test[:N_EVAL]

with torch.no_grad():
    predY_test = model_test(eval_X).cpu()  # raw logits (no sigmoid applied)

# The model outputs logits, so probabilities must be computed before the
# 0.5 decision threshold is applied.
probY_test = torch.sigmoid(predY_test)
pred = (probY_test > 0.5).float()

# Same criterion as in training, evaluated on the logits (not on the
# thresholded 0/1 predictions, for which cross-entropy is meaningless).
error = nn.BCEWithLogitsLoss()
loss = error(predY_test, eval_y)

print('Loss (test data): ' + str(loss.item()))
print('---------------------------------------------------------------------------------------------------------------------------')

# Labels reported in detail. Columns are looked up by name so that the name
# shown, the prediction column and the ground-truth column always match;
# when none of these names exists, the first three label columns are used.
REPORTED_LABELS = ['Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes']
report_columns = [labels.index(name) for name in REPORTED_LABELS if name in labels]
if not report_columns:
    report_columns = list(range(min(3, len(labels))))

for col, marker in zip(report_columns, ['x', 'o', 's']):
    name = labels[col]

    accuracy = (pred[:, col] == eval_y[:, col]).float().mean()
    print(f'Accuracy of {name} (test data): {accuracy.item()*100}%')

    # Predicted probability against the true label for this attribute.
    plt.figure(figsize=(10, 6))
    plt.plot(probY_test[:, col].numpy(), eval_y[:, col].numpy(), marker, alpha=0.2, label=name)
    plt.ylabel('True')
    plt.xlabel('Pred')
    plt.legend()
    plt.show()


#                  GradCam

In [None]:
path_to_module = os.path.join(os.getcwd(), 'pytorch-grad-cam')
sys.path.append(path_to_module)


from pytorch_grad_cam.grad_cam import GradCAM
from pytorch_grad_cam.hirescam import HiResCAM
from pytorch_grad_cam.grad_cam_elementwise import GradCAMElementWise
from pytorch_grad_cam.ablation_layer import AblationLayer, AblationLayerVit, AblationLayerFasterRCNN
from pytorch_grad_cam.ablation_cam import AblationCAM
from pytorch_grad_cam.xgrad_cam import XGradCAM
from pytorch_grad_cam.grad_cam_plusplus import GradCAMPlusPlus
from pytorch_grad_cam.score_cam import ScoreCAM
from pytorch_grad_cam.layer_cam import LayerCAM
from pytorch_grad_cam.eigen_cam import EigenCAM
from pytorch_grad_cam.eigen_grad_cam import EigenGradCAM
from pytorch_grad_cam.random_cam import RandomCAM
from pytorch_grad_cam.fullgrad_cam import FullGrad
from pytorch_grad_cam.guided_backprop import GuidedBackpropReLUModel
from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients
from pytorch_grad_cam.feature_factorization.deep_feature_factorization import DeepFeatureFactorization, run_dff_on_image
import pytorch_grad_cam.utils.model_targets
import pytorch_grad_cam.utils.reshape_transforms
import pytorch_grad_cam.metrics.cam_mult_image
import pytorch_grad_cam.metrics.road
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.base_cam import BaseCAM


In [None]:
class GradCAM(BaseCAM):
    """Grad-CAM variant of ``BaseCAM``: channel weights are spatially averaged gradients.

    NOTE: defining this class rebinds the name ``GradCAM`` in the notebook
    namespace, replacing any previously imported class of the same name.
    """

    def __init__(self, model, target_layers, reshape_transform=None):
        super().__init__(model, target_layers, reshape_transform)

    def get_cam_weights(self, input_tensor, target_layer, target_category, activations, grads):
        """Return the mean of ``grads`` over every axis after the first two.

        4-D gradients (2D images) are averaged over axes (2, 3) and 5-D
        gradients (3D images) over axes (2, 3, 4).

        Raises:
            ValueError: if ``grads`` is neither 4-D nor 5-D.
        """
        spatial_axes = {4: (2, 3), 5: (2, 3, 4)}
        rank = len(grads.shape)
        if rank not in spatial_axes:
            raise ValueError("Invalid grads shape. Shape of grads should be 4 (2D image) or 5 (3D image).")
        return np.mean(grads, axis=spatial_axes[rank])


In [None]:
model_test.eval()
target_layers = [model_test.resnet_model.layer4[-1]]  # last block of the last ResNet stage

# Initialize Grad-CAM
cam = GradCAM(model=model_test, target_layers=target_layers)


def visualize_gradcam(example_image, target_category):
    """Overlay the Grad-CAM map of output ``target_category`` on ``example_image``.

    The image is fed to the model exactly as the training and test tensors
    were built (float32, axes swapped with ``transpose(1, 3)``, no
    normalization). Running it through ``ToTensor`` and an ImageNet
    ``Normalize`` instead would hand the network inputs with a different
    layout and value range than it was trained on, so the resulting maps
    would not explain the model's real behaviour.

    Args:
        example_image: one image as an (H, W, C) array.
            NOTE(review): values are assumed to lie in [0, 1], as the
            ``* 255`` rescaling used elsewhere in this notebook suggests.
        target_category: index of the model output to explain; also used to
            look up the title in ``labels``.
    """
    raw_image = np.asarray(example_image, dtype=np.float32)

    # Batch of one, prepared like torch_X_train / torch_X_test:
    # (1, H, W, C) -> transpose(1, 3) -> (1, C, W, H).
    input_tensor = (
        torch.from_numpy(raw_image[None, :, :, :])
        .transpose(1, 3)
        .contiguous()
        .to(DEVICE)
    )

    # Generate CAM
    grayscale_cam = cam(input_tensor=input_tensor, targets=[ClassifierOutputTarget(target_category)])[0, :]

    # The map follows the tensor's spatial layout (W, H); transpose it back
    # so that it lines up with the (H, W, C) image used for display.
    grayscale_cam = grayscale_cam.T

    # show_cam_on_image expects a float image with values in [0, 1].
    display_image = np.clip(raw_image, 0, 1)

    # Visualize CAM
    cam_image = show_cam_on_image(display_image, grayscale_cam, use_rgb=True)
    plt.imshow(cam_image)
    plt.title(f'Grad-CAM for class {labels[target_category]}')
    plt.show()


# Grad-CAM for every label on 20 randomly chosen training images.
for j in random.sample(range(X_train.shape[0]), 20):
    for i, label in enumerate(labels):
        print(f'Visualizing Grad-CAM for class: {label}')
        visualize_gradcam(X_train[j], i)  # one map per label index