In [1]:
import sqlite3

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, PackedSequence
from torch import Tensor
from torch.utils.data import TensorDataset, DataLoader
import torch.nn.functional as F

from climb_conversion import ClimbsFeatureArray, ClimbsFeatureScaler
from hold_classifier import UNetHoldClassifierLogits, train_unet_hold_classifier_logits
from simple_diffusion import ClimbDDPM, Noiser, zero_com

# CLASSIFIER_SAVE_PATH = "data/weights/unet-hold-classifier.pth"

# climbs = ClimbsFeatureArray()
# dataset = climbs.get_features(roles=True)

# hold_classifier = UNetHoldClassifierLogits()
# train_unet_hold_classifier_logits(
#     hold_classifier,
#     dataset,
#     epochs=20,
#     batch_size=1024,
#     num_workers=2,
#     save_path= CLASSIFIER_SAVE_PATH
# )

In [None]:
DB_PATH = "data/storage.db"
SCALER_WEIGHTS_PATH = 'data/weights/climbs-feature-scaler.joblib'
DDPM_WEIGHTS_PATH = 'data/weights/simple-diffusion-large.pth'
HC_WEIGHTS_PATH = 'data/weights/unet-hold-classifier.pth'
WALL_ID = 'wall-0a877f13d8e5'

GRADE_TO_DIFF = {
    "font": {
        "4a": 10, "4b": 11, "4c": 12,
        "5a": 13, "5b": 14, "5c": 15,
        "6a": 16, "6a+": 17, "6b": 18, "6b+": 19,
        "6c": 20, "6c+": 21,
        "7a": 22, "7a+": 23, "7b": 24, "7b+": 25,
        "7c": 26, "7c+": 27,
        "8a": 28, "8a+": 29, "8b": 30, "8b+": 31,
        "8c": 32, "8c+": 33,
    },
    "v_grade": {
        "V0-": 10, "V0": 11, "V0+": 12,
        "V1": 13, "V1+": 14, "V2": 15,
        "V3": 16, "V3+": 17, "V4": 18, "V4+": 19,
        "V5": 20, "V5+": 21, "V6": 22, "V6+": 22.5,
        "V7": 23, "V7+": 23.5, "V8": 24, "V8+": 25,
        "V9": 26, "V9+": 26.5, "V10": 27, "V10+": 27.5,
        "V11": 28, "V11+": 28.5, "V12": 29, "V12+": 29.5,
        "V13": 30, "V13+": 30.5, "V14": 31, "V14+": 31.5,
        "V15": 32, "V15+": 32.5, "V16": 33,
    },
}


class ClimbDDPMGenerator():
    def __init__(
            self,
            wall_id: str,
            scaler: ClimbsFeatureScaler,
            ddpm: ClimbDDPM,
            hold_classifier: UNetHoldClassifierLogits,
        ):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.scaler = scaler
        self.ddpm = ddpm
        self.hold_classifier = hold_classifier
        self.timesteps = 100

        with sqlite3.connect(DB_PATH) as conn:
            holds = pd.read_sql_query("SELECT hold_index, x, y, pull_x, pull_y, useability, is_foot, wall_id FROM holds WHERE wall_id = ?",conn,params=(wall_id,))
            scaled_holds = self.scaler.transform_hold_features(holds, to_df=True)
            self.holds_manifold = torch.tensor(scaled_holds[['x','y','pull_x','pull_y']].values, dtype=torch.float32)
            self.holds_lookup = scaled_holds['hold_index'].values
        
        self.holds_lookup = np.concatenate([self.holds_lookup, np.array([-1, -1, -1, -1])])
        
        self.holds_manifold = torch.cat([
            self.holds_manifold,
            torch.tensor(
                [[-2.0, 0.0, -2.0, 0.0],
                [2.0, 0.0, -2.0, 0.0],
                [-2.0, 0.0, 2.0, 0.0],
                [2.0, 0.0, 2.0, 0.0]],dtype=torch.float32)
            ],dim=0)

    def _build_cond_tensor(self, n, grade, diff_scale, angle):
        diff = GRADE_TO_DIFF[diff_scale][grade]
        df_cond = pd.DataFrame({
            "grade": [diff]*n,
            "quality": [2.9]*n,
            "ascents": [100]*n,
            "angle": [angle]*n
        })

        cond = self.scaler.transform_climb_features(df_cond).T
        return torch.tensor(cond, device=self.device, dtype=torch.float32)
    
    def _project_onto_manifold(self, gen_climbs: Tensor, offset_manifold: Tensor)-> Tensor:
        """
            Project each generated hold to its nearest neighbor on the hold manifold.
            
            Args:
                gen_climbs: (B, S, H) predicted clean holds
                return_indices: (boolean) Whether to return the hold indices or hold feature coordinates
            Returns:
                projected: (B, S, H) each hold snapped to nearest manifold point
        """
        B, S, H = gen_climbs.shape
        flat_climbs = gen_climbs.reshape(-1,H)
        dists = torch.cdist(flat_climbs, offset_manifold)
        idx = dists.argmin(dim=1)
        return self.holds_manifold[idx].reshape(B, S, -1)
        
    def _project_onto_indices(self, gen_climbs: Tensor, cond_t: Tensor, offset_manifold: Tensor) -> list[list[int]]:
        """Project climb onto the final hold indices (and remove null holds)"""
        
        B, S, H = gen_climbs.shape

        roles = torch.argmax(self.hold_classifier(gen_climbs, cond_t), dim=2).detach().numpy()

        flat_climbs = gen_climbs.reshape(-1,H)
        dists = torch.cdist(flat_climbs, offset_manifold)
        idx = dists.argmin(dim=1)
        holds = self.holds_lookup[idx]
        holds = holds.reshape(B, S)

        print(type(holds),type(roles))
        
        # Mask null holds to be role 4
        is_null = (holds == -1)
        roles[is_null] = 4
        
        # Concatenate indices and roles
        climbs = np.stack([holds, roles], axis=2)
        
        # Convert climbs into list[np.array] filtering on role != 4 
        climbs = [c[c[:,1] != 4].tolist() for c in climbs]
        
        return climbs
    
    def _projection_strength(self, t: Tensor, t_start_projection: float = 0.8):
        """Calculate the weight to assign to the projected holds based on the timestep."""
        a = (t_start_projection-t)/t_start_projection
        strength = 1 - torch.cos(a*torch.pi/2)
        return torch.where(t > t_start_projection, torch.zeros_like(t), strength).unsqueeze(2)
    
    @torch.no_grad()
    def generate(
        self,
        n: int = 1 ,
        angle: int = 45,
        grade: str = 'V4',
        diff_scale: str = 'v_grade',
        deterministic: bool = False
    )->list[list[int]]:
        """
        Generate a climb or batch of climbs with the given conditions using the standard DDPM iterative denoising process.
        
        :param n: Number of climbs to generate
        :type n: int
        :param angle: Angle of the wall
        :type angle: int
        :param grade: Desired difficulty (V-grade)
        :type grade: int | None
        :return: A Tensor containing the denoised generated climbs as hold sets.
        :rtype: Tensor
        """
        cond_t = self._build_cond_tensor(n, grade, diff_scale, angle)
        x_t = torch.randn((n, 20, 4), device=self.device)
        noisy = x_t.clone()
        t_tensor = torch.ones((n,1), device=self.device)
        
        # Randomly offset the holds-manifold to allow for climbs to be generated at different x-coordinates around the wall.
        x_offset = np.random.randn()
        offset_manifold = self.holds_manifold.clone()
        offset_manifold[:,0] += x_offset*0.1

        for t in range(0, self.timesteps):
            print('.',end='')

            gen_climbs = self.ddpm(noisy, cond_t, t_tensor)

            alpha_p = self._projection_strength(t_tensor)
            projected_climbs = self._project_onto_manifold(gen_climbs, offset_manifold)
            print(projected_climbs.shape, alpha_p.shape)
            gen_climbs = alpha_p*(projected_climbs) + (1-alpha_p)*(gen_climbs)
            
            t_tensor -= 1.0/self.timesteps
            noisy = self.ddpm.forward_diffusion(gen_climbs, t_tensor, x_t if deterministic else torch.randn_like(x_t))
        
        return self._project_onto_indices(gen_climbs, cond_t, offset_manifold)

ddpm = ClimbDDPM(
    model=Noiser(),
    weights_path=DDPM_WEIGHTS_PATH,
    timesteps=100,
)
scaler = ClimbsFeatureScaler(
    weights_path=SCALER_WEIGHTS_PATH
)
hold_classifier = UNetHoldClassifierLogits(
    weights_path=HC_WEIGHTS_PATH
)
generator = ClimbDDPMGenerator(
    wall_id=WALL_ID,
    scaler=scaler,
    ddpm=ddpm,
    hold_classifier=hold_classifier
)
climbs = generator.generate(n=2)
for c in climbs:
    print(c)

  state_dict = torch.load(filepath, map_location=map_loc)


.

RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x2 and 4x128)