# Imports

In [1]:
from utils.imports import *
from utils.data_loader import download_data, load_data
from utils.utils import preprocess_images, get_current_time, calculate_pca, apply_pca_to_rois, GaborPyramid
from utils.config import batch_size, num_epochs, model_str

# Initialize data

In [2]:
fnames = ["../kay_labels.npy", "../kay_labels_val.npy", "../kay_images.npz"]
urls = ["https://osf.io/r638s/download",
        "https://osf.io/yqb3e/download",
        "https://osf.io/ymnjv/download"]

if download_data(fnames, urls):
    init_training_inputs, init_test_inputs, training_outputs, test_outputs, roi, roi_names, labels, val_labels = load_data('../kay_images.npz')

# Gabor Wavelet Pyramid

In [3]:
import torch
import torch.nn as nn
import numpy as np
from torch.nn.parameter import Parameter

class GaborWaveletPyramid(nn.Module):
    def __init__(self, max_cycles_per_fov, num_orientations, num_phases, image_resolution, output_size=8428):
        super(GaborWaveletPyramid, self).__init__()
       
        
        self.max_cycles_per_fov = max_cycles_per_fov
        self.num_orientations = num_orientations
        self.num_phases = num_phases
        self.image_resolution = image_resolution
        self.spatial_frequencies = 2 ** np.arange(np.log2(max_cycles_per_fov))
        
        # Define parameters of the model
        self.num_wavelets = len(self.spatial_frequencies) * num_orientations * num_phases
        self.kernel = Parameter(torch.zeros(self.num_wavelets))  # The weights for each wavelet
        self.dc_offset = Parameter(torch.zeros(1))  # The DC offset

        # Generate Gabor wavelets
        self.gabor_wavelets = self.generate_gabor_wavelets(max_cycles_per_fov, num_orientations, num_phases, image_resolution)

        #visualize_gabor_wavelets(self.gabor_wavelets)

        self.fc = nn.Linear(self.num_wavelets, output_size)
        self.precomputed = None

    
    def generate_gabor_wavelets(self, max_cycles_per_fov, num_orientations, num_phases, image_resolution):

        # Initialize the Gabor wavelet tensor
        spatial_frequencies = 2 ** np.arange(np.log2(max_cycles_per_fov))
        num_wavelets = len(spatial_frequencies) * num_orientations * num_phases
        print("num_wavelets: ")
        print(num_wavelets)
        gabor_wavelets = torch.zeros(num_wavelets, image_resolution, image_resolution)
        
        # Gabor wavelet parameters
        sigma = lambda f: 0.56 / f
        wavelet_size = lambda f: np.ceil(sigma(f) * 2.5) * 2 + 1  # To ensure an odd-sized filter
        
        # Generate the wavelets
        count = 0
        for freq_idx, freq in enumerate(spatial_frequencies):
            for orientation in np.linspace(0, np.pi, num_orientations, endpoint=False):
                for phase in np.linspace(0, 2*np.pi, num_phases, endpoint=False):
                    # Calculate the wavelet size
                    filter_sigma = sigma(freq)
                    size = np.ceil(filter_sigma * 2.5) * 2 + 1
                    size = int(size) if int(size) % 2 == 1 else int(size) + 1  # Ensure the size is odd

                    # Generate the Gabor wavelet
                    x, y = np.meshgrid(
                        np.linspace(-size//2, size//2, size),
                        np.linspace(-size//2, size//2, size)
                    )
                    rotx = x * np.cos(orientation) + y * np.sin(orientation)
                    roty = -x * np.sin(orientation) + y * np.cos(orientation)
                    
                    gabor = np.exp(-(rotx**2 + roty**2) / (2 * sigma(freq)**2)) * np.cos(2 * np.pi * freq * rotx / image_resolution + phase)
                    
                    # Normalize to zero mean and unit variance
                    gabor -= gabor.mean()
                    gabor /= gabor.std()

                    # Calculate the indices for placement
                    half_size = size // 2
                    center = image_resolution // 2
                    wavelet_slice = slice(center - half_size, center + half_size + 1)
                
                    # Place the generated wavelet in the tensor
                    gabor_wavelets[count, wavelet_slice, wavelet_slice] = torch.from_numpy(gabor)
                    count += 1
                    
        return gabor_wavelets
    def compute_gabor_responses(self, images):
        """Precompute Gabor responses for a batch of images."""
        with torch.no_grad():
            batch_responses = []
            n = 0
            for image in images:
                n+=1
                #print "n / totalimages loaded"
                print(n, "/", len(images))
                image = image.unsqueeze(0) if image.dim() == 2 else image
                contrast_energy = torch.zeros(image.size(0), self.num_wavelets, device=image.device)
                for i, wavelet in enumerate(self.gabor_wavelets):
                    response = nn.functional.conv2d(image, wavelet.unsqueeze(0).unsqueeze(0), padding='same')
                    energy = response**2
                    contrast_energy[:, i] = energy.view(energy.size(0), -1).sum(dim=1)
                batch_responses.append(contrast_energy.sqrt())
            return torch.cat(batch_responses)
        
    def compute_and_save_gabor_responses(self, images, file_path):
        """Precompute Gabor responses for a batch of images and save to a file."""
        if isinstance(images, np.ndarray):
            images = torch.from_numpy(images)
        responses = self.compute_gabor_responses(images)
        torch.save(responses, file_path)
        print(f"Saved Gabor responses to {file_path}")

    def load_precomputed_gabor_responses(self, file_path):
        """Load precomputed Gabor responses from a file."""
        self.precomputed = torch.load(file_path)
        print(f"Loaded Gabor responses from {file_path}")
        
    def forward(self, image):

        contrast_energy = self.precomputed if self.precomputed is not None else self.compute_gabor_responses(image)

        output = self.fc(contrast_energy)

        return output

# Example instantiation of the model
# These values should be determined based on the specifics of the experiment and data
max_cycles_per_fov = 16
num_orientations = 8
num_phases = 2
image_resolution = 64  # Example resolution, actual value should be based on the data

# Create the Gabor Wavelet Pyramid model
gwp_model = GaborWaveletPyramid(max_cycles_per_fov, num_orientations, num_phases, image_resolution)

print(gwp_model)

#gwp_model.forward(init_test_inputs[0])

init_test_inputs.shape

num_wavelets: 
64
GaborWaveletPyramid(
  (fc): Linear(in_features=64, out_features=8428, bias=True)
)


(120, 128, 128)

In [4]:
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, inputs, outputs):
        self.inputs = torch.from_numpy(inputs).float().unsqueeze(1)  # Add channel dimension
        self.outputs = torch.from_numpy(outputs).float()

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.outputs[idx]
#first 2 training inputs
train_inputs2 = init_training_inputs[:10]
train_outputs2 = training_outputs[:10]


In [None]:
import torch
from torch.utils.data import random_split, DataLoader
from copy import deepcopy

def plot_losses(train_losses, val_losses):
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.title('Training and Validation Losses')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

# Define your dataset (assuming MyDataset is your custom dataset class)
full_dataset = MyDataset(init_training_inputs, training_outputs)

# Split your dataset into training and validation sets
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

# Create data loaders for training and validation sets
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

# Define a simple training loop with early stopping
def train_model_with_early_stopping(model, train_loader, val_loader, epochs, learning_rate, patience):
    criterion = nn.MSELoss()  # Mean Squared Error Loss for regression tasks
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    best_loss = float('inf')
    best_model = None
    epochs_no_improve = 0
    train_losses, val_losses = [], []

    for epoch in range(epochs):
        # Training phase
        print("In epoch")
        model.train()
        running_loss = 0.0
        print("in train loop")
        for inputs, targets in train_loader:
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Validation phase
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            print("in val loop")
            for inputs, targets in val_loader:
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()
        
        # Calculate average losses
        avg_train_loss = running_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)

        # Print training and validation loss
        print(f"Epoch {epoch+1}/{epochs} - Training Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}")

        # Early stopping logic
        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            best_model = deepcopy(model.state_dict())  # Save the best model
            epochs_no_improve = 0
            torch.save(best_model, f'../trained_models/gabor/val_model_bests_of_all.pth')
        else:
            epochs_no_improve += 1
            if epochs_no_improve == patience:
                print(f"Early stopping triggered after {epoch+1} epochs.")
                break
        if (epoch + 1) % 10 == 0:
            # Plot loss curve
            plt.figure(figsize=(10, 5))
            plt.title(f'{model_str} - Encoder Loss')
            plt.plot(train_losses, label="Encoder Loss")
            plt.plot(val_losses, label="Encoder Validation Loss")
            plt.xlabel("Iterations")
            plt.ylabel("Epoch")
            plt.legend()
            plt.show()
    
    # Load the best model
    model.load_state_dict(best_model)
    plot_losses(train_losses, val_losses)
    return model

# Initialize your model
gwp_model = GaborWaveletPyramid(max_cycles_per_fov=16, num_orientations=8, num_phases=2, image_resolution=128)

file_path = 'all_training_inputs.pt'

#gwp_model.compute_and_save_gabor_responses(init_training_inputs, file_path)

gwp_model.load_precomputed_gabor_responses(file_path)

# Train your model with early stopping
trained_model = train_model_with_early_stopping(
    gwp_model,
    train_loader,
    val_loader,
    epochs=5000,  # Number of epochs to train for
    learning_rate= 0.001, #0.001,
    patience=10  # Number of epochs to wait for improvement before stopping
)


In [17]:
# Try model predicting data from a dataloader
def predicted_actual_values(model, dataloader):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    all_preds = []
    all_actual = []
    model.eval()
    with torch.no_grad():
        for data, targets in dataloader:
            # Ensure data is on the same device as the model
            data = data.to(device, dtype=torch.float)
            targets = targets.to(device, dtype=torch.float)

            #outputs, _ = model.forward_with_intermediate(data)
            outputs = model.forward(data)
            all_preds.extend(outputs.cpu().numpy())
            all_actual.extend(targets.cpu().numpy())

    return np.array(all_preds), np.array(all_actual)

# Compute Root Mean Squared Error
def compute_rmse(predictions, actual):
    rmse = mean_squared_error(actual, predictions, squared=False)
    return rmse

# Compute R-squared (Coefficient of Determination)
def compute_r2_score(predictions, actual):
    r2 = r2_score(actual, predictions)
    return r2

# Compute Pearson Correlation Coefficient
def compute_pearson_correlation(predictions, actual):
    # Flatten the predictions and actual arrays in case they have more than one dimension
    predictions_flat = predictions.flatten()
    actual_flat = actual.flatten()
    correlation, _ = pearsonr(predictions_flat, actual_flat)
    return correlation

# Try running trained model on test data
def test_trained_model(model, batch_size):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    
    #x_test = reshaped_gabor
    x_test = init_test_inputs
    y_test = test_outputs
    #y_test = test_outputs

    # Create tensor from training inputs and targets
    x_test_tensor, y_test_tensor = torch.from_numpy(x_test).float(), torch.from_numpy(y_test).float()
    test_data_tensor = torch.utils.data.TensorDataset(x_test_tensor, y_test_tensor)
    # Create dataloader from tensor
    test_dataloader = torch.utils.data.DataLoader(test_data_tensor, batch_size, shuffle=True)
    
    # Compute accuracy of model predictions
    predicted_results, actual_results = predicted_actual_values(model, test_dataloader)
    print('Predictions:')
    print(predicted_results[0][0:10])
    print('Actual:')
    print(actual_results[0][0:10])

    rmse = compute_rmse(predicted_results, actual_results)
    r2 = compute_r2_score(predicted_results, actual_results)
    pearson_correlation = compute_pearson_correlation(predicted_results, actual_results)
    
    print(f'Achieved RMSE: {rmse:.2f}')
    print(f'R-squared: {r2:.2f}')
    print(f'Pearson Correlation Coefficient: {pearson_correlation:.2f}')

In [21]:
# Load model

#model = trained_model

model = GaborWaveletPyramid(max_cycles_per_fov=16, num_orientations=8, num_phases=2, image_resolution=128)

#model.compute_and_save_gabor_responses(init_training_inputs[:1], "test_inputs_gabor1.pt")
model.load_precomputed_gabor_responses("test_inputs_gabor.pt")

print('Model loaded:', model_str) 
model.load_state_dict(torch.load('../trained_models/gabor/val_model_bestss.pth'))



# Test model
test_trained_model(model=model, batch_size=batch_size)

num_wavelets: 
64
Loaded Gabor responses from test_inputs_gabor.pt
Model loaded: AlexNet
Predictions:
[-0.17746465 -0.08122709  0.08913706 -0.12076683 -0.01345443 -0.0484757
 -0.04273073  0.06267595  0.06798709 -0.04870605]
Actual:
[-0.770339   -0.29902986  0.15490855 -0.08567116 -0.12409987  0.2816567
  0.0260905  -0.6763842   0.17463025  0.5590392 ]
Achieved RMSE: 0.49
R-squared: -0.13
Pearson Correlation Coefficient: 0.02


# NOTES FROM EXPERMENTING

3 droputs of 0.9 seemed to be promising but stagnated

3 dropouts of 0.85 stagnated after crossing 

3 dopouts of 0.75


Next time, try between 0.75 and 0.85
otherwise between 0.85 and 0.9

Implement feature layer output saved as pickle file to run model faster

Try with another dateset to confirm wehter it is just a "lorte dataset"

Save every 10 epoch as well as the best

**Evaluation metrics**

1. **RMSE**: Root Mean Squared Error is the square root of the mean of the squared errors. The RMSE is a good measure of how accurately the model predicts the response, and it is the most important criterion for fit if the main purpose of the model is prediction. A lower RMSE is better as it indicates a closer fit of the model to the data.
2. **R2**: R-squared values typically range from 0 to 1 and can be interpreted as the proportion of the variance in the dependent variable that is predictable from the independent variables. Negative values of R-squared indicate that the model fits the data worse than a horizontal hyperplane at the mean of the dependent variable. This suggests that the model is not capturing the variance of the data well and is performing poorly on this task. It could be due to an overfitted model, a wrong model choice, or irrelevant features.
3. **Pearson correlation**: Pearson correlation coefficient measures the linear correlation between two variables. The coefficient values range between -1 and 1. A value close to 1 implies that there is a strong positive correlation between the two variables. A value close to -1 implies that there is a strong negative correlation between the two variables. A value close to 0 implies that there is no linear correlation between the two variables.


## Visualize intermediate outputs