In [None]:
import torch
import torch.nn.functional as F
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Load an image
image = Image.open("/home/phamtam/Pictures/images.jpeg")

# Define the transformations
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Apply the transformations to the image
image_tensor = transform(image)

# Add a batch dimension to the image tensor
image_tensor = image_tensor.unsqueeze(0)

# Identity affine transform (2x3 matrix) with a leading batch dimension
theta = torch.tensor([
    [1, 0, 0],
    [0, 1, 0]
], dtype=torch.float)
theta = theta.unsqueeze(0)

# affine_grid and grid_sample must use the same align_corners convention.
# Leaving affine_grid on its default while sampling with align_corners=True
# slightly rescales the image, so the identity theta would not reproduce it.
grid = F.affine_grid(theta, image_tensor.size(), align_corners=True)

# Apply grid sampling to the image
output = F.grid_sample(image_tensor, grid, align_corners=True)

# Convert the output tensor to a PIL image
output_image = transforms.ToPILImage()(output.squeeze(0))

# Display the original and transformed images
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
axs[0].imshow(image)
axs[0].set_title("Original Image")
axs[0].axis("off")
axs[1].imshow(output_image)
axs[1].set_title("Transformed Image")
axs[1].axis("off")
plt.tight_layout()
plt.show()

In [1]:
import os
import time
import copy
from functools import partial
from typing import Optional, Callable, Any
from collections import OrderedDict
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from einops import rearrange, repeat
from timm.models.layers import DropPath, trunc_normal_
from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count, parameter_count
DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})"

# Optional selective-scan CUDA extensions: each import is attempted and any failure is ignored.
try:
    import selective_scan_cuda_oflex
except Exception as e:
    ...
try:
    import selective_scan_cuda_core
except Exception as e:
    ...

try:
    import selective_scan_cuda
except Exception as e:
    ...

# fvcore flops

def gather_by_angle(tensor, angle):
    """Shift each row of a 4-D tensor horizontally by an angle-dependent offset and flatten.

    Row ``y`` is read starting ``round(y * cos(angle))`` columns further right.
    Column indices that fall outside ``[0, W - 1]`` are clamped, so edge values
    are repeated. Only the cosine of the angle is used; rows are never shifted
    vertically.

    Args:
        tensor: Tensor whose ``size()`` unpacks to ``(B, C, H, W)``.
        angle: Angle in degrees.

    Returns:
        Tensor of shape ``(B, C, H * W)``. The last two dims of the gathered
        ``(B, C, H, W)`` result are swapped before flattening, so the values of
        one column are contiguous in the output.
    """
    B, C, H, W = tensor.size()
    step_x = math.cos(math.radians(angle))

    cols = torch.arange(0, W, device=tensor.device)
    rows = torch.arange(0, H, device=tensor.device)

    # One horizontal offset per row, shape (H,). Deriving the offsets from a 2-D
    # meshgrid made the index 5-D, which expand()/gather() reject for a 4-D input.
    row_offsets = (rows * step_x).round().long()

    # (H, 1) + (W,) -> (H, W): source column for every output position
    gather_indices = (row_offsets.unsqueeze(-1) + cols).clamp(0, W - 1)
    gather_indices = gather_indices.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)

    # Gather the elements along the width dimension
    gathered = tensor.gather(3, gather_indices)

    return gathered.transpose(-1, -2).reshape(B, C, H * W)

def scatter_by_angle(tensor_flat, original_shape, angle):
    """Scatter a flattened tensor back onto a ``(B, C, H, W)`` grid along an angle.

    The flat input is read as a ``(B, C, W, H)`` view and transposed to
    ``(B, C, H, W)``, which is the layout ``gather_by_angle`` produces. Element
    ``(y, x)`` is then written to column ``clamp(round(y * cos(angle)) + x, 0, W - 1)``
    of row ``y``.

    Args:
        tensor_flat: Tensor reshapeable to ``(B, C, W, H)``.
        original_shape: Sequence that unpacks to ``(B, C, H, W)``.
        angle: Angle in degrees.

    Returns:
        Tensor of shape ``(B, C, H, W)`` with the dtype and device of
        ``tensor_flat``. Positions that receive no value stay zero.

    Note:
        Clamping can map several sources to one column; ``scatter_`` does not
        define which of the colliding values is kept.
    """
    B, C, H, W = original_shape
    step_x = math.cos(math.radians(angle))

    cols = torch.arange(0, W, device=tensor_flat.device)
    rows = torch.arange(0, H, device=tensor_flat.device)

    # One horizontal offset per row, shape (H,). Deriving the offsets from a 2-D
    # meshgrid made the index 5-D, which expand()/scatter_() reject here.
    row_offsets = (rows * step_x).round().long()

    # (H, 1) + (W,) -> (H, W): destination column for every source position
    scatter_indices = (row_offsets.unsqueeze(-1) + cols).clamp(0, W - 1)
    scatter_indices = scatter_indices.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)

    # Undo the transpose-then-flatten layout used by gather_by_angle
    source = tensor_flat.reshape(B, C, W, H).transpose(-1, -2)

    # Create an empty tensor to store the scattered result
    result_tensor = torch.zeros(B, C, H, W, device=tensor_flat.device, dtype=tensor_flat.dtype)

    # Scatter the values back along the width dimension
    result_tensor.scatter_(3, scatter_indices, source)

    return result_tensor

class _ChannelLayerNorm(nn.LayerNorm):
    """LayerNorm applied over the channel axis of a ``(N, C, H, W)`` tensor."""

    def forward(self, x):
        # nn.LayerNorm normalises the trailing dimension(s), so move channels
        # last for the normalisation and restore the NCHW layout afterwards.
        return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)


class SelectiveDirection(nn.Module):
    """Predict one 2-D point per outer cell of a 3x3 grid, plus its angle.

    The input is split into ``n_groups`` channel groups and a 3x3 grid of
    spatial cells (the centre cell is skipped). A small conv head maps every
    cell to two values, which are averaged over the cell's spatial extent and
    over the channel groups.

    Args:
        in_channels: Number of input channels; must be divisible by ``n_groups``.
        n_groups: Number of channel groups.

    Raises:
        ValueError: If ``in_channels`` is not divisible by ``n_groups``.
    """

    def __init__(
        self,
        in_channels,
        n_groups,
    ):
        super().__init__()
        if in_channels % n_groups != 0:
            raise ValueError(
                f"in_channels ({in_channels}) must be divisible by n_groups ({n_groups})"
            )
        self.in_channels = in_channels
        self.n_groups = n_groups
        self.n_group_channels = self.in_channels // self.n_groups
        self.conv_offset = nn.Sequential(
            nn.Conv2d(self.n_group_channels, self.n_group_channels, 3, 1, 1, groups=self.n_group_channels),
            # A plain nn.LayerNorm(C) on an NCHW tensor would normalise the width axis.
            _ChannelLayerNorm(self.n_group_channels),
            nn.GELU(),
            nn.Conv2d(self.n_group_channels, 2, 1, 1, 0, bias=False)
        )

    def forward(self, x):
        """Compute the points and angles for a ``(B, C, H, W)`` input.

        Rows and columns beyond ``3 * (H // 3)`` and ``3 * (W // 3)`` are ignored.

        Returns:
            Tuple ``(points, angles)`` of shapes ``(B, 8, 2, 1, 1)`` and
            ``(B, 8, 1, 1, 1)``.
        """
        B, C, H, W = x.shape
        x_groups = x.reshape(B, self.n_groups, self.n_group_channels, H, W)

        # Divide the image into 9 sub-sections and visit all but the centre one
        sub_h, sub_w = H // 3, W // 3
        points = []
        for i in range(3):
            for j in range(3):
                if i == 1 and j == 1:
                    continue  # Skip the center sub-section
                sub_section = x_groups[..., i*sub_h:(i+1)*sub_h, j*sub_w:(j+1)*sub_w]
                sub_section = sub_section.reshape(
                    B * self.n_groups, self.n_group_channels, sub_h, sub_w
                )
                offsets = self.conv_offset(sub_section)  # (B*G, 2, sub_h, sub_w)
                # Pool over space first; the conv output cannot be reshaped to
                # (B, G, 2, 1, 1) directly unless the cell is a single pixel.
                offsets = offsets.mean(dim=(2, 3))  # (B*G, 2)
                offsets = offsets.reshape(B, self.n_groups, 2).mean(dim=1)  # (B, 2)
                points.append(offsets)

        points = torch.stack(points, dim=1).reshape(B, 8, 2, 1, 1)

        # Compute angles from the center point to the learned points
        angles = self.create_angle(points)

        return points, angles

    def create_angle(self, points):
        """Return ``atan2(y - 0.5, x - 0.5)`` for every point.

        Args:
            points: Tensor of shape ``(B, 8, 2, 1, 1)`` holding ``(x, y)`` pairs.

        Returns:
            Tensor of shape ``(B, 8, 1, 1, 1)`` with angles in radians, computed
            with tensor ops so it stays on the input's device and in its dtype.
        """
        center_x, center_y = 0.5, 0.5  # Assuming normalized coordinates
        angles = torch.atan2(points[:, :, 1] - center_y, points[:, :, 0] - center_x)
        return angles.reshape(points.shape[0], 8, 1, 1, 1)

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image, ImageDraw
import math
import numpy as np

# The tensor fed to the module below is a 3-channel RGB image, so the module is
# built for 3 channels in a single group. A 512-channel / 32-group module cannot
# reshape a 3-channel input into its channel groups.
in_channels = 3
n_groups = 1
model = SelectiveDirection(in_channels, n_groups)

# Load and preprocess the input image
image_path = "/home/phamtam/Pictures/images.jpeg"
image = Image.open(image_path).convert("RGB")
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
input_tensor = transform(image).unsqueeze(0)  # Add batch dimension
print(input_tensor.shape)
print(np.array(image).shape)
# Forward pass through the SelectiveDirection module
with torch.no_grad():
    points, angles = model(input_tensor)

# Visualize the learned attention points and angles on the input image
def visualize_points_and_angles(image, points, angles):
    """Draw the first batch item's points and angle rays onto ``image``.

    Each point's two values are multiplied by the image width and height to get
    pixel coordinates and marked with a red dot. Each angle (radians) is drawn
    as a blue 50-pixel ray from the image centre. The drawing is done on the
    passed image itself, which is also the return value.
    """
    canvas = ImageDraw.Draw(image)
    width, height = image.size
    cx, cy = width // 2, height // 2

    n_points = points.shape[1]
    for idx in range(n_points):
        px = int(points[0, idx, 0, 0, 0].item() * width)
        py = int(points[0, idx, 1, 0, 0].item() * height)
        theta = angles[0, idx, 0, 0, 0].item()

        # Red marker: a 4x4 pixel bounding box centred on the point
        marker_box = (px - 2, py - 2, px + 2, py + 2)
        canvas.ellipse(marker_box, fill="red")

        # Blue ray from the centre in the direction of the angle
        ray_x = cx + 50 * math.cos(theta)
        ray_y = cy + 50 * math.sin(theta)
        canvas.line((cx, cy, ray_x, ray_y), fill="blue", width=1)

    return image

# Draw the predicted points and angle rays onto `image`, keep the returned image
# in `visualized_image`, and call its .show() to display it.
visualized_image = visualize_points_and_angles(image, points, angles)
visualized_image.show()

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image, ImageDraw
import math
from einops import rearrange

class LayerNormProxy(nn.Module):
    """LayerNorm over the channel axis of a channels-first 4-D tensor.

    The input is rearranged so that its second axis becomes the last one,
    normalised with ``nn.LayerNorm(dim)``, and rearranged back.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.norm = nn.LayerNorm(dim)

    def forward(self,x):
        # Channels last for the norm, then back to the original axis order
        channels_last = rearrange(x,'b c h m  -> b h m c')
        return rearrange(self.norm(channels_last),'b h m c -> b c h m')

# Define the SelectiveDirection module
class SelectiveDirection(nn.Module):
    """Predict one 2-D point per outer cell of a 3x3 grid, plus its angle.

    The input is split into ``n_groups`` channel groups and a 3x3 grid of
    spatial cells (the centre cell is skipped). A small conv head maps every
    cell to two values, which are averaged over the cell's spatial extent and
    over the channel groups.

    Args:
        in_channels: Number of input channels; must be divisible by ``n_groups``.
        n_groups: Number of channel groups.

    Raises:
        ValueError: If ``in_channels`` is not divisible by ``n_groups``.
    """

    def __init__(
        self,
        in_channels,
        n_groups,
    ):
        super().__init__()
        if in_channels % n_groups != 0:
            raise ValueError(
                f"in_channels ({in_channels}) must be divisible by n_groups ({n_groups})"
            )
        self.in_channels = in_channels
        self.n_groups = n_groups
        self.n_group_channels = self.in_channels // self.n_groups
        self.conv_offset = nn.Sequential(
            nn.Conv2d(self.n_group_channels, self.n_group_channels, 3, 1, 1, groups=self.n_group_channels),
            LayerNormProxy(self.n_group_channels),
            nn.GELU(),
            nn.Conv2d(self.n_group_channels, 2, 1, 1, 0, bias=False)
        )

    def forward(self, x):
        """Compute the points and angles for a ``(B, C, H, W)`` input.

        Rows and columns beyond ``3 * (H // 3)`` and ``3 * (W // 3)`` are ignored.

        Returns:
            Tuple ``(points, angles)`` of shapes ``(B, 8, 2, 1, 1)`` and
            ``(B, 8, 1, 1, 1)``.
        """
        B, C, H, W = x.shape
        x_groups = x.reshape(B, self.n_groups, self.n_group_channels, H, W)

        # Divide the image into 9 sub-sections and visit all but the centre one
        sub_h, sub_w = H // 3, W // 3
        points = []
        for i in range(3):
            for j in range(3):
                if i == 1 and j == 1:
                    continue  # Skip the center sub-section
                sub_section = x_groups[..., i*sub_h:(i+1)*sub_h, j*sub_w:(j+1)*sub_w]
                sub_section = sub_section.reshape(
                    B * self.n_groups, self.n_group_channels, sub_h, sub_w
                )
                offsets = self.conv_offset(sub_section)  # (B*G, 2, sub_h, sub_w)
                # Pool over space first; the conv output cannot be reshaped to
                # (B, G, 2, 1, 1) directly unless the cell is a single pixel.
                offsets = offsets.mean(dim=(2, 3))  # (B*G, 2)
                offsets = offsets.reshape(B, self.n_groups, 2).mean(dim=1)  # (B, 2)
                points.append(offsets)

        points = torch.stack(points, dim=1).reshape(B, 8, 2, 1, 1)

        # Compute angles from the center point to the learned points
        angles = self.create_angle(points)

        return points, angles

    def create_angle(self, points):
        """Return ``atan2(y - 0.5, x - 0.5)`` for every point.

        Args:
            points: Tensor of shape ``(B, 8, 2, 1, 1)`` holding ``(x, y)`` pairs.

        Returns:
            Tensor of shape ``(B, 8, 1, 1, 1)`` with angles in radians, computed
            with tensor ops so it stays on the input's device and in its dtype.
        """
        center_x, center_y = 0.5, 0.5  # Assuming normalized coordinates
        angles = torch.atan2(points[:, :, 1] - center_y, points[:, :, 0] - center_x)
        return angles.reshape(points.shape[0], 8, 1, 1, 1)
# Build the module for a 3-channel input handled as a single group
in_channels, n_groups = 3, 1
model = SelectiveDirection(in_channels, n_groups)

# Read the image as RGB and set up the preprocessing pipeline:
# resize to 224x224, convert to a tensor, normalise each channel
image_path = "/home/phamtam/Pictures/images.jpeg"
image = Image.open(image_path).convert("RGB")
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
input_tensor = transform(image)[None]  # prepend the batch dimension

# Run the module without tracking gradients
with torch.no_grad():
    points, angles = model(input_tensor)

def visualize_points_and_angles(image, points, angles):
    """Overlay the first batch item's points and angles on ``image``.

    Point values are scaled by the image width/height and drawn as red dots;
    each angle (radians) is drawn as a blue ray of length 50 starting at the
    image centre. ``image`` is drawn on directly and returned.
    """
    pen = ImageDraw.Draw(image)
    img_w, img_h = image.size
    mid_x = img_w // 2
    mid_y = img_h // 2

    for k in range(points.shape[1]):
        angle_k = angles[0, k, 0, 0, 0].item()
        dot_x = int(points[0, k, 0, 0, 0].item() * img_w)
        dot_y = int(points[0, k, 1, 0, 0].item() * img_h)

        # Point marker
        pen.ellipse((dot_x - 2, dot_y - 2, dot_x + 2, dot_y + 2), fill="red")

        # Direction ray from the image centre
        tip_x = mid_x + 50 * math.cos(angle_k)
        tip_y = mid_y + 50 * math.sin(angle_k)
        pen.line((mid_x, mid_y, tip_x, tip_y), fill="blue", width=1)

    return image

# Draw the predicted points and angle rays onto `image`, keep the returned image
# in `visualized_image`, and call its .show() to display it.
visualized_image = visualize_points_and_angles(image, points, angles)
visualized_image.show()

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image, ImageDraw
import math

# Define the SelectiveDirection module
class SelectiveDirection(nn.Module):
    """Predict one 2-D point per outer cell of a 3x3 grid, plus its angle.

    The input is cut into a 3x3 grid of spatial cells (the centre cell is
    skipped). A grouped 3x3 conv, GELU and a 1x1 conv map each cell to two
    channels, which are averaged over the cell's spatial extent.

    Args:
        in_channels: Number of input channels.
        n_groups: ``groups`` argument of the first convolution.
    """

    def __init__(
        self,
        in_channels,
        n_groups,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.n_groups = n_groups
        self.n_group_channels = self.in_channels // self.n_groups
        self.conv_offset = nn.Sequential(
            nn.Conv2d(self.in_channels, self.in_channels, 3, 1, 1, groups=self.n_groups),
            nn.GELU(),
            nn.Conv2d(self.in_channels, 2, 1, 1, 0, bias=False)
        )

    def forward(self, x):
        """Compute the points and angles for a ``(B, C, H, W)`` input.

        Rows and columns beyond ``3 * (H // 3)`` and ``3 * (W // 3)`` are ignored.

        Returns:
            Tuple ``(points, angles)`` of shapes ``(B, 8, 2, 1, 1)`` and
            ``(B, 8, 1, 1, 1)``.
        """
        B, C, H, W = x.shape

        # Divide the image into 9 sub-sections and visit all but the centre one
        sub_h, sub_w = H // 3, W // 3
        points = []
        for i in range(3):
            for j in range(3):
                if i == 1 and j == 1:
                    continue  # Skip the center sub-section
                sub_section = x[:, :, i*sub_h:(i+1)*sub_h, j*sub_w:(j+1)*sub_w]
                offsets = self.conv_offset(sub_section)  # (B, 2, sub_h, sub_w)
                points.append(offsets.mean(dim=(2, 3)))  # (B, 2)

        # (B, 8, 2): sub-section index, then the two predicted values
        points = torch.stack(points, dim=1).reshape(B, 8, 2, 1, 1)

        # Compute angles from the center point to the learned points
        angles = self.create_angle(points)

        return points, angles

    def create_angle(self, points):
        """Return ``atan2(y - 0.5, x - 0.5)`` for every point.

        Args:
            points: Tensor of shape ``(B, 8, 2, 1, 1)`` holding ``(x, y)`` pairs.

        Returns:
            Tensor of shape ``(B, 8, 1, 1, 1)`` with angles in radians, computed
            with tensor ops so it stays on the input's device and in its dtype.
        """
        center_x, center_y = 0.5, 0.5  # Assuming normalized coordinates
        angles = torch.atan2(points[:, :, 1] - center_y, points[:, :, 0] - center_x)
        return angles.reshape(points.shape[0], 8, 1, 1, 1)

# Instantiate the module: 3 input channels, one group
in_channels = 3
n_groups = 1
model = SelectiveDirection(in_channels, n_groups)

# Open the image as RGB; preprocessing resizes to 224x224, converts to a
# tensor and normalises each channel with the given mean/std
image_path = "/home/phamtam/Pictures/images.jpeg"
image = Image.open(image_path).convert("RGB")
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
input_tensor = torch.unsqueeze(transform(image), 0)  # batch of one

# Gradient-free forward pass
with torch.no_grad():
    points, angles = model(input_tensor)

def visualize_points_and_angles(image, points, angles):
    """Annotate ``image`` with the points and angles of the first batch item.

    For every entry along ``points.shape[1]`` a red dot is drawn at the point
    (its values scaled by image width and height) and a blue 50-pixel line is
    drawn from the image centre along the angle (radians). The same image
    object that was passed in is modified and returned.
    """
    drawer = ImageDraw.Draw(image)
    image_width, image_height = image.size
    origin = (image_width // 2, image_height // 2)

    for index in range(points.shape[1]):
        x_pix = int(points[0, index, 0, 0, 0].item() * image_width)
        y_pix = int(points[0, index, 1, 0, 0].item() * image_height)
        direction = angles[0, index, 0, 0, 0].item()

        # Red dot at the predicted point
        drawer.ellipse((x_pix - 2, y_pix - 2, x_pix + 2, y_pix + 2), fill="red")

        # Blue line from the centre along the predicted direction
        end_x = origin[0] + 50 * math.cos(direction)
        end_y = origin[1] + 50 * math.sin(direction)
        drawer.line((origin[0], origin[1], end_x, end_y), fill="blue", width=1)

    return image

# Draw the predicted points and angle rays onto `image`, keep the returned image
# in `visualized_image`, and call its .show() to display it.
visualized_image = visualize_points_and_angles(image, points, angles)
visualized_image.show()