In [1]:
import sqlite3

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, PackedSequence
from torch import Tensor
from torch.utils.data import TensorDataset, DataLoader
import torch.nn.functional as F

from climb_conversion import ClimbsFeatureArray, ClimbsFeatureScaler
from hold_classifier import UNetHoldClassifierLogits
from simple_diffusion import ClimbDDPM, Noiser, zero_com

# Load the climbs and build the training dataset; roles=True also returns per-hold role targets.
climbs = ClimbsFeatureArray()
dataset = climbs.get_features(limit = 500, roles=True)
# Peek at one sample. Per the output below it is a tuple of (hold features (20, 4), climb
# condition (4,), role targets ...) — null/padding holds appear as [-2, 0, -2, 0] rows.
dataset[0]

Initializing ClimbsFeatureArray...
ClimbsFeatureArray initialized! 132057 unique climbs added!


(tensor([[ 0.0000,  0.0000,  0.1117, -0.5647],
         [-0.0870, -0.5405, -0.0031, -0.0611],
         [ 0.0000, -0.5405, -0.0031, -0.0611],
         [-0.1739, -0.2162,  0.0357,  0.0094],
         [-0.1304,  0.1622, -0.3267, -0.5005],
         [-0.0435,  0.2703, -0.0112, -0.4194],
         [ 0.1304,  0.4865,  0.2634, -0.5250],
         [ 0.0000,  0.6486, -0.2856, -0.8135],
         [ 0.3913,  0.2703, -0.5112, -0.3729],
         [ 0.0435,  0.8108,  0.2071, -0.8304],
         [ 0.0000,  0.8649,  0.0413, -0.8501],
         [ 0.1304,  1.1351,  0.5235, -0.7133],
         [-0.1739,  1.2973,  0.0599, -0.8491],
         [-2.0000,  0.0000, -2.0000,  0.0000],
         [-2.0000,  0.0000, -2.0000,  0.0000],
         [-2.0000,  0.0000, -2.0000,  0.0000],
         [-2.0000,  0.0000, -2.0000,  0.0000],
         [-2.0000,  0.0000, -2.0000,  0.0000],
         [-2.0000,  0.0000, -2.0000,  0.0000],
         [-2.0000,  0.0000, -2.0000,  0.0000]]),
 tensor([-0.4788, -0.9626,  0.4949,  0.1429]),
 tensor([[1

In [None]:
class ResidualBlock1D_V2(nn.Module):
    """Two-convolution 1D residual block with FiLM-style (scale/shift) conditioning.

    Data flow: conv1 -> GroupNorm -> h * (1 + gamma) + beta -> SiLU -> conv2 -> GroupNorm -> SiLU,
    then the (1x1-projected when widths differ) input is added back.

    Args:
        in_channels: channels of the input ``x``.
        out_channels: channels produced; must be divisible by 8 because of ``GroupNorm(8, ·)``.
        cond_dim: width of the conditioning vector ``cond``.
        padding: padding of both kernel-size-3 convolutions (1 keeps the sequence length).
    """

    def __init__(self, in_channels, out_channels, cond_dim, padding=1):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=padding)
        self.conv1 = nn.Conv1d(in_channels, out_channels, **conv_kwargs)
        self.conv2 = nn.Conv1d(out_channels, out_channels, **conv_kwargs)
        self.norm1, self.norm2 = nn.GroupNorm(8, out_channels), nn.GroupNorm(8, out_channels)
        self.act = nn.SiLU()

        # One linear layer emits both FiLM parameters: gamma and beta, out_channels each.
        self.cond_proj = nn.Linear(cond_dim, 2 * out_channels)
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv1d(in_channels, out_channels, 1)

    def forward(self, x, cond):
        """Apply the block.

        Args:
            x: (B, in_channels, L) sequence.
            cond: (B, cond_dim) conditioning vector.
        Returns:
            (B, out_channels, L') where L' == L when padding=1.
        """
        film = self.cond_proj(cond)[..., None]  # (B, 2 * out_channels, 1): broadcasts over L
        gamma, beta = film.chunk(2, dim=1)

        hidden = self.norm1(self.conv1(x))
        hidden = self.act(hidden * (1 + gamma) + beta)
        hidden = self.act(self.norm2(self.conv2(hidden)))

        return hidden + self.shortcut(x)

class UNetHoldClassifierLogits(nn.Module):
    """1D U-Net over a climb's hold sequence that outputs per-hold role logits.

    NOTE(review): this redefines (and shadows) the ``UNetHoldClassifierLogits`` imported from
    ``hold_classifier`` in the first cell — keep the two in sync or drop one.

    Args:
        in_features_dim: features per hold (last axis of ``x``).
        in_cond_dim: width of the climb-level condition vector ``cond``.
        out_dim: number of logits produced per hold.
        hidden_dim: base channel width; encoder widths are hidden_dim, 2*hidden_dim, ...,
            (n_layers + 1) * hidden_dim.
        n_layers: number of down blocks (and mirrored up blocks).
    """

    def __init__(
        self,
        in_features_dim: int = 4,
        in_cond_dim: int = 4,
        out_dim: int = 5,
        hidden_dim: int = 128,
        n_layers: int = 2,
    ):
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Condition embedding shared by every residual block.
        self.cond_emb = nn.Sequential(
            nn.Linear(in_cond_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.init_conv = ResidualBlock1D_V2(in_features_dim, hidden_dim, hidden_dim)

        levels = range(1, n_layers + 1)
        # Encoder widens by hidden_dim per level.
        self.down_blocks = nn.ModuleList(
            [ResidualBlock1D_V2(hidden_dim * k, hidden_dim * (k + 1), hidden_dim) for k in levels]
        )
        # Decoder mirrors the encoder so each output width matches the skip it is added to.
        self.up_blocks = nn.ModuleList(
            [ResidualBlock1D_V2(hidden_dim * (k + 1), hidden_dim * k, hidden_dim) for k in reversed(levels)]
        )
        self.head = nn.Conv1d(hidden_dim, out_dim, 1)

    def forward(self, x, cond):
        """Return per-hold logits.

        Args:
            x: hold features, presumably (B, S, in_features_dim); transposed to channels-first
                for the Conv1d blocks.
            cond: climb condition, (B, in_cond_dim).
        Returns:
            Logits transposed back to (B, S, out_dim) under the same assumption on ``x``.
        """
        # Project helper from simple_diffusion; presumably re-centres the first 2 features — TODO confirm.
        x = zero_com(x, 2)

        cond_vec = self.cond_emb(cond)
        hidden = self.init_conv(x.transpose(1, 2), cond_vec)

        skips = []
        for down in self.down_blocks:
            skips.append(hidden)
            hidden = down(hidden, cond_vec)

        # Skips are consumed deepest-first, pairing each decoder level with its encoder twin.
        for up, skip in zip(self.up_blocks, reversed(skips)):
            hidden = skip + up(hidden, cond_vec)

        return self.head(hidden).transpose(1, 2)

def train_unet_hold_classifier_logits(
    hold_classifier: nn.Module,
    dataset: TensorDataset,
    epochs: int = 100,
    batch_size: int = 128,
    num_workers: int = 0,
    save_path: str | None = None,
    save_on_best: bool = True,
    torch_compile: bool = False
) -> tuple[nn.Module, list[float]]:
    """Train a per-hold role classifier with cross-entropy loss and plot the loss curve.

    Each dataset item must unpack as ``(x, cond, target_role_probs)``. The model is called as
    ``hold_classifier(x, cond)`` and is expected to return per-hold logits ``(B, S, C)``; the
    targets are ``argmax`` over axis 2 of ``target_role_probs`` (assumed ``(B, S, C)``).

    Args:
        hold_classifier: model to train (moved to CUDA when available and set to train mode).
        dataset: training data yielding ``(x, cond, target_role_probs)`` triples.
        epochs: number of passes over the data.
        batch_size: mini-batch size; the incomplete final batch is dropped.
        num_workers: DataLoader worker processes.
        save_path: where to write the model ``state_dict``; nothing is saved when falsy.
        save_on_best: if True, save each time an epoch reaches a new lowest mean batch loss and
            keep that best checkpoint at the end; if False, save the final weights once.
        torch_compile: wrap the model with ``torch.compile`` before training.

    Returns:
        ``(model, epoch_losses)``: the trained (possibly compiled) model in its final state and
        the mean batch loss of every epoch.

    Raises:
        ValueError: if an epoch produces no full batch (dataset smaller than ``batch_size``).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Model set-up. Keep a handle on the uncompiled module: torch.compile returns a wrapper whose
    # state_dict keys carry an "_orig_mod." prefix, which would not load into a plain model.
    hold_classifier.to(device)
    hold_classifier.train()
    base_model = hold_classifier
    if torch_compile:
        hold_classifier = torch.compile(hold_classifier)
    optimizer = torch.optim.Adam(params=hold_classifier.parameters())

    n_params = sum(p.numel() for p in base_model.parameters())

    # DataLoader set-up
    batches = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=num_workers
    )

    epoch_losses: list[float] = []
    saved_best = False
    with tqdm(range(epochs)) as pbar:
        for epoch in pbar:
            batch_losses = []
            for x, cond, target_role_probs in batches:
                x, cond = x.to(device), cond.to(device)
                # Targets must be on the logits' device, otherwise cross_entropy fails on CUDA.
                target_roles = torch.argmax(target_role_probs.to(device), dim=2)
                optimizer.zero_grad()

                pred_logits = hold_classifier(x, cond)

                # cross_entropy expects the class axis at dim 1: (B, S, C) -> (B, C, S).
                loss = F.cross_entropy(pred_logits.transpose(1, 2), target_roles)

                loss.backward()
                optimizer.step()

                batch_losses.append(loss.item())

            if not batch_losses:
                raise ValueError(
                    f"No full batches: the dataset has fewer than batch_size={batch_size} samples "
                    "and drop_last=True discards the partial batch."
                )
            mean_batch_loss = sum(batch_losses) / len(batch_losses)
            epoch_losses.append(mean_batch_loss)
            info_str = f"Epoch: {epoch}: Avg Batch Loss: {mean_batch_loss:.3f}, Min Batch Loss: {min(epoch_losses):.3f}. {len(batch_losses)} batches, {n_params} params"
            if save_path and save_on_best and (min(epoch_losses) == mean_batch_loss):
                pbar.set_postfix_str(f"{info_str}. New best mean batch loss! Saving hold classifier at {save_path}...")
                torch.save(base_model.state_dict(), save_path)
                saved_best = True
            else:
                pbar.set_postfix_str(info_str)

    # With save_on_best the best checkpoint is already on disk; saving again would overwrite it
    # with the last epoch's weights. Still save if no best checkpoint was ever written.
    if save_path and not (save_on_best and saved_best):
        print(f"Saving hold classifier at {save_path}...")
        torch.save(base_model.state_dict(), save_path)

    # Plot training results
    fig, ax = plt.subplots()
    ax.plot(range(len(epoch_losses)), epoch_losses)
    ax.set_yscale('log')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Mean batch cross-entropy loss')
    ax.set_title(f'Mean Batch-Loss per Epoch, U-Net Hold Classifier ({n_params} params)')
    plt.show()

    return hold_classifier, epoch_losses

CLASSIFIER_SAVE_PATH = "data/weights/unet-hold-classifier"

# Bind the trained model and per-epoch losses to names instead of leaving them only in the
# cell's Out[] (which also dumps the whole module repr and every loss value into the notebook).
unet_hold_classifier, unet_epoch_losses = train_unet_hold_classifier_logits(
    UNetHoldClassifierLogits(),
    dataset,
    save_path=CLASSIFIER_SAVE_PATH,
)

 85%|████████▌ | 85/100 [02:42<00:29,  1.95s/it, Epoch: 84: Avg Batch Loss: 0.000, Min Batch Loss: 0.000. 7 batches, 2307589 params. New best mean batch loss! Saving hold classifier at data/weights/unet-hold-classifier...]

In [None]:
# Parameter count of a 3-level U-Net, for comparison with the default 2-level model.
sum(map(torch.numel, UNetHoldClassifierLogits(n_layers=3).parameters()))

In [None]:
# Wall to generate climbs on; matched against holds.wall_id when ClimbDDPMGenerator loads its holds.
WALL_ID = 'wall-443c15cd12e0'


class EGNNHoldClassifier(nn.Module):
    def __init__(self, weights_path: str | None = None, input_dim: int = 8, hidden_dim: int = 256, num_layers: int = 1):
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            dropout = 0.3,
            bidirectional=True,
            device=self.device,
            batch_first=True
        )

        self.classification_head = nn.Sequential(
            nn.Linear(hidden_dim*2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2)
        )

        self.loss_func = nn.BCEWithLogitsLoss()

        if weights_path:
            self.load_state_dict(torch.load(weights_path, map_location = self.device))
    
    def loss(self, pred_roles: Tensor, true_roles: Tensor):
        """Get the loss from the model's predictions, via cross-entropy loss."""
        return self.loss_func(pred_roles, true_roles)

    
    def forward(self, holds_cond: PackedSequence | Tensor)-> Tensor:
        """Run the forward pass. Predicts the roles for a given (possibly batched) set of holds, given (possibly batched) wall conditions."""

        _, (hs, cs) = self.lstm(holds_cond)

        lstm_final_state = torch.cat([hs[-1],hs[-2]], dim=1)
        
        sf_logits = self.classification_head(lstm_final_state)

        return sf_logits

In [43]:


# Grade string -> numeric difficulty, per grading scale ("font" and "v_grade").
# Both scales share one numeric axis (e.g. "7a" and "V6" both map to 22, "8c+" and "V16" to 33).
# ClimbDDPMGenerator._build_cond_tensor looks up GRADE_TO_DIFF[diff_scale][grade] and passes the
# number to the scaler as the "grade" condition feature.
# NOTE(review): the numbers presumably match the difficulty encoding used in the training DB — confirm.
GRADE_TO_DIFF = {
    "font": {
        "4a": 10, "4b": 11, "4c": 12,
        "5a": 13, "5b": 14, "5c": 15,
        "6a": 16, "6a+": 17, "6b": 18, "6b+": 19,
        "6c": 20, "6c+": 21,
        "7a": 22, "7a+": 23, "7b": 24, "7b+": 25,
        "7c": 26, "7c+": 27,
        "8a": 28, "8a+": 29, "8b": 30, "8b+": 31,
        "8c": 32, "8c+": 33,
    },
    "v_grade": {
        "V0-": 10, "V0": 11, "V0+": 12,
        "V1": 13, "V1+": 14, "V2": 15,
        "V3": 16, "V3+": 17, "V4": 18, "V4+": 19,
        "V5": 20, "V5+": 21, "V6": 22, "V6+": 22.5,
        "V7": 23, "V7+": 23.5, "V8": 24, "V8+": 25,
        "V9": 26, "V9+": 26.5, "V10": 27, "V10+": 27.5,
        "V11": 28, "V11+": 28.5, "V12": 29, "V12+": 29.5,
        "V13": 30, "V13+": 30.5, "V14": 31, "V14+": 31.5,
        "V15": 32, "V15+": 32.5, "V16": 33,
    },
}

class ClimbDDPMGenerator():
    """Generate climbs on one wall by DDPM denoising with progressive snapping onto real holds.

    The wall's holds are read from the ``holds`` table, scaled with ``scaler``, and stored as a
    "manifold" of valid hold feature vectors (x, y, pull_x, pull_y) plus four null-hold sentinel
    rows. During denoising, predicted holds are increasingly blended toward their nearest
    manifold point. The final climb is then read off as a list of wall hold indices.

    Args:
        wall_id: ``holds.wall_id`` of the wall to generate on.
        scaler: fitted ClimbsFeatureScaler used for hold and climb-condition features.
        ddpm: trained ClimbDDPM, called as ``ddpm(noisy, cond, t)`` and
            ``ddpm.forward_diffusion(x0, t, noise)``.
        role_classifer: optional model used by ``generate(classify_holds=True)``; it must map a
            (1, S, 4 + cond_dim) sequence to (1, 2) logits. (Misspelt name kept for compatibility.)
        db_path: SQLite database path; defaults to the notebook-level ``DB_PATH``.
        timesteps: number of denoising steps; should match the DDPM's schedule.
    """

    def __init__(
            self,
            wall_id: str,
            scaler: ClimbsFeatureScaler,
            ddpm: ClimbDDPM,
            role_classifer: nn.Module | None = None,
            db_path: str | None = None,
            timesteps: int = 100,
        ):
        if timesteps < 1:
            raise ValueError(f"timesteps must be >= 1, got {timesteps}")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.wall_id = wall_id
        self.scaler = scaler
        self.ddpm = ddpm
        self.timesteps = timesteps
        # Always define the attribute so generate() can report a missing classifier clearly.
        self.role_classifier = role_classifer

        # sqlite3's connection context manager only commits/rolls back; it never closes.
        conn = sqlite3.connect(DB_PATH if db_path is None else db_path)
        try:
            holds = pd.read_sql_query(
                "SELECT hold_index, x, y, pull_x, pull_y, useability, is_foot, wall_id FROM holds WHERE wall_id = ?",
                conn,
                params=(wall_id,),
            )
        finally:
            conn.close()

        scaled_holds = self.scaler.transform_hold_features(holds, to_df=True)
        wall_holds = torch.tensor(scaled_holds[['x', 'y', 'pull_x', 'pull_y']].values, dtype=torch.float32)
        # "No hold" sentinels; presumably mirror the [-2, 0, -2, 0] padding rows of the training data.
        null_holds = torch.tensor(
            [[-2.0, 0.0, -2.0, 0.0],
             [2.0, 0.0, -2.0, 0.0],
             [-2.0, 0.0, 2.0, 0.0],
             [2.0, 0.0, 2.0, 0.0]], dtype=torch.float32)
        # Keep the manifold on the same device as the DDPM outputs it is compared against.
        self.holds_manifold = torch.cat([wall_holds, null_holds], dim=0).to(self.device)
        # Row i of holds_manifold <-> holds_lookup[i]; the null-hold rows map to -1.
        self.holds_lookup = np.concatenate([scaled_holds['hold_index'].values, np.array([-1, -1, -1, -1])])

    def _build_cond_tensor(self, n, grade, diff_scale, angle):
        """Build the scaled condition tensor for ``n`` climbs of one grade and wall angle.

        quality and ascents are fixed at 2.9 and 100 for every generated climb.
        Raises KeyError for an unknown ``diff_scale``/``grade`` combination.
        """
        try:
            diff = GRADE_TO_DIFF[diff_scale][grade]
        except KeyError as err:
            raise KeyError(f"Unknown grade {grade!r} for difficulty scale {diff_scale!r}") from err
        df_cond = pd.DataFrame({
            "grade": [diff]*n,
            "quality": [2.9]*n,
            "ascents": [100]*n,
            "angle": [angle]*n
        })

        cond = self.scaler.transform_climb_features(df_cond).T
        return torch.tensor(cond, device=self.device, dtype=torch.float32)

    def _project_onto_manifold(self, gen_climbs: Tensor, offset_manifold: Tensor) -> Tensor:
        """
            Snap each generated hold to its nearest point on the offset hold manifold.

            Args:
                gen_climbs: (B, S, H) predicted clean holds, in the same (offset) frame as offset_manifold
                offset_manifold: (M, H) hold manifold shifted by the generation x-offset
            Returns:
                projected: (B, S, H) each hold replaced by its nearest offset_manifold row
        """
        B, S, H = gen_climbs.shape
        nearest = torch.cdist(gen_climbs.reshape(-1, H), offset_manifold).argmin(dim=1)
        # Return points from the frame the distances were measured in; the un-offset manifold
        # would pull every hold x_offset away from the hold it matched.
        return offset_manifold[nearest].reshape(B, S, -1)

    def _unique_hold_rows(self, gen_climb: Tensor, offset_manifold: Tensor) -> list[int]:
        """Map one climb's (S, H) holds to manifold rows: null holds dropped, first occurrence kept, order preserved."""
        dists = torch.cdist(gen_climb.reshape(-1, gen_climb.shape[-1]), offset_manifold)
        first_row_per_hold = {}
        for row in dists.argmin(dim=1).cpu().tolist():
            # Null sentinels are -1; real hold indices (including 0) are kept.
            if self.holds_lookup[row] >= 0:
                first_row_per_hold.setdefault(self.holds_lookup[row].item(), row)
        # Insertion order = generation order, which role assignment relies on (a set would scramble it).
        return list(first_row_per_hold.values())

    def _project_onto_indices(self, gen_climbs: Tensor, offset_manifold: Tensor) -> list[list[int]]:
        """Project each (S, H) climb of a (B, S, H) batch onto wall hold indices (null holds and duplicates removed)."""
        return [
            [self.holds_lookup[row].item() for row in self._unique_hold_rows(gen_climb, offset_manifold)]
            for gen_climb in gen_climbs
        ]

    def _projection_strength(self, t: Tensor, t_start_projection: float = 0.3):
        """Weight given to the projected holds at normalised timestep ``t`` (tensor shaped like ``t``).

        0 while t > t_start_projection, then 1 - cos(pi/2 * (t_start - t) / t_start), reaching 1 at t = 0.
        """
        a = (t_start_projection-t)/t_start_projection
        strength = 1 - torch.cos(a*torch.pi/2)
        return torch.where(t > t_start_projection, torch.zeros_like(t), strength)

    def _classify_roles(self, gen_climb: Tensor, offset_manifold: Tensor, cond: Tensor) -> list[list[int]]:
        """Return [hold_index, role] pairs for one generated climb (0 = start, 1 = finish, 2 = other).

        The first hold is a start and the last a finish; the classifier's two logits decide whether
        the second / second-to-last holds are also starts / finishes.
        """
        rows = self._unique_hold_rows(gen_climb, offset_manifold)
        if not rows:
            return []

        # Classifier input per hold: real-wall hold features followed by the climb condition,
        # aligned one-to-one with the holds that are returned.
        hold_feats = self.holds_manifold[rows]
        cond_feats = cond.reshape(1, -1).expand(len(rows), -1)
        input_seq = torch.cat([hold_feats, cond_feats], dim=1).unsqueeze(0)

        sf_logits = self.role_classifier(input_seq)
        dual_start = bool(sf_logits[0][0] > 0)
        dual_fin = bool(sf_logits[0][1] > 0)

        roles = [[self.holds_lookup[row].item(), 2] for row in rows]
        roles[0][1] = 0
        roles[-1][1] = 1
        # Only add a second start/finish when it does not overwrite the other end's role.
        dual_start = dual_start and len(roles) >= 3
        dual_fin = dual_fin and len(roles) >= (4 if dual_start else 3)
        if dual_start:
            roles[1][1] = 0
        if dual_fin:
            roles[-2][1] = 1
        return roles

    @torch.no_grad()
    def generate(
        self,
        n: int = 1 ,
        angle: int = 45,
        grade: str = 'V4',
        diff_scale: str = 'v_grade',
        deterministic: bool = False,
        classify_holds: bool = False
    )->list[list[int]]:
        """
        Generate a climb or batch of climbs with the given conditions using the standard DDPM iterative denoising process.

        :param n: Number of climbs to generate
        :type n: int
        :param angle: Angle of the wall
        :type angle: int
        :param grade: Desired difficulty, a key of GRADE_TO_DIFF[diff_scale] (e.g. 'V4' or '6b')
        :type grade: str
        :param diff_scale: Grading scale of ``grade``: 'v_grade' or 'font'
        :type diff_scale: str
        :param deterministic: Re-use the initial noise at every re-noising step instead of fresh noise
        :type deterministic: bool
        :param classify_holds: Also assign start/finish roles (first generated climb only)
        :type classify_holds: bool
        :return: Without classify_holds, one list of hold indices per climb (null holds and
            duplicates removed, generation order kept). With classify_holds, a list of
            [hold_index, role] pairs for the first climb only.
        :rtype: list[list[int]]
        :raises ValueError: if classify_holds is True but no role classifier was given.
        """
        if classify_holds and self.role_classifier is None:
            raise ValueError("classify_holds=True requires a role classifier (pass role_classifer=...).")

        cond_t = self._build_cond_tensor(n, grade, diff_scale, angle)
        x_t = torch.randn((n, 20, 4), device=self.device)
        noisy = x_t.clone()
        t_tensor = torch.ones((n, 1), device=self.device)

        # Randomly offset the holds-manifold to allow for climbs to be generated at different x-coordinates around the wall.
        x_offset = float(np.random.randn() * 0.2)
        offset_manifold = self.holds_manifold.clone()
        offset_manifold[:, 0] += x_offset

        for _ in range(self.timesteps):
            print('.', end='')

            gen_climbs = self.ddpm(noisy, cond_t, t_tensor)

            # (n, 1) -> (n, 1, 1) so the per-climb weight broadcasts over (n, 20, 4) for any n.
            alpha_p = self._projection_strength(t_tensor).view(-1, 1, 1)
            projected_climbs = self._project_onto_manifold(gen_climbs, offset_manifold)
            gen_climbs = alpha_p * projected_climbs + (1 - alpha_p) * gen_climbs

            t_tensor -= 1.0 / self.timesteps
            noisy = self.ddpm.forward_diffusion(gen_climbs, t_tensor, x_t if deterministic else torch.randn_like(x_t))

        if classify_holds:
            return self._classify_roles(gen_climbs[0], offset_manifold, cond_t[0])
        return self._project_onto_indices(gen_climbs, offset_manifold)

# Build the generator from the trained DDPM denoiser, the feature scaler and the role classifier.
# NOTE(review): DDPM_WEIGHTS_PATH, SCALER_WEIGHTS_PATH, LSTM_WEIGHTS_PATH and HoldClassifier (and
# DB_PATH, read inside ClimbDDPMGenerator) are not defined in any cell shown here, so this cell
# only runs on a kernel where some earlier or deleted cell defined them — TODO define them in a
# config cell so Restart & Run All works.
model = ClimbDDPM(
    model=Noiser(),
    weights_path=DDPM_WEIGHTS_PATH,
    timesteps=100,
)
scaler = ClimbsFeatureScaler(
    weights_path=SCALER_WEIGHTS_PATH
)
role_classifier = HoldClassifier(LSTM_WEIGHTS_PATH)
generator = ClimbDDPMGenerator(
    wall_id=WALL_ID,
    scaler=scaler,
    ddpm=model,
    role_classifer=role_classifier
)
# Generate with generate()'s defaults (one V4 climb at 45 degrees) and return [hold_index, role] pairs.
generator.generate(classify_holds=True)

  state_dict = torch.load(filepath, map_location=map_loc)
  self.load_state_dict(torch.load(weights_path, map_location = self.device))


....................................................................................................

[[640, 0],
 [897, 2],
 [741, 2],
 [935, 2],
 [874, 2],
 [939, 2],
 [748, 2],
 [817, 2],
 [788, 1]]