In [1]:
import sys
import os
# Presumably lets the project-local packages imported later in this notebook
# (rws, evaluation, stand_alone) resolve; assumes the kernel was started from
# the repository root — TODO confirm.
sys.path.append(os.getcwd())
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

import matplotlib
# Module-level 'jet' colormap.
# `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
# `pyplot.get_cmap` returns the same registered colormap on all versions.
cmap = plt.get_cmap('jet')


# init

In [2]:
from typing import Optional, Mapping, Callable, Union, Sequence,  Tuple
import numpy as np
import matplotlib.cm
import matplotlib.pyplot as plt


# An RGBA color, as produced by calling a matplotlib colormap.
ColorType = Tuple[float, float, float, float]
# Colormap used by the color helpers when none is given.
DEFAULT_COLORMAP_NAME = 'jet'

# Array-like inputs accepted by the plotting helpers.
_Number = Union[float, int]
ArrayLike = Union[Sequence[_Number], np.ndarray]
IntArrayLike = Union[Sequence[int], np.ndarray]


def get_colors_list(n: int, color_map=DEFAULT_COLORMAP_NAME) -> Sequence[ColorType]:
  """Return ``n`` RGBA colors taken at evenly spaced indices of a colormap.

  Colors are sampled at colormap indices ``0, step, 2*step, ...`` with
  ``step = 256 // n``, which keeps the colors of different phases well apart.
  For ``n > 256`` the indices wrap around and colors repeat.

  Args:
    n: Number of colors to return; ``n <= 0`` yields an empty list.
    color_map: Colormap name (or colormap instance) accepted by
      ``plt.get_cmap``.

  Returns:
    A list of exactly ``n`` RGBA tuples.
  """
  if n <= 0:
    return []
  # Honor `color_map` (it used to be ignored in favor of the default) and use
  # pyplot's accessor: `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9.
  cmap = plt.get_cmap(color_map)
  # max(1, ...) avoids a zero step when n > 256; the modulo wraps indices then.
  step = max(1, 256 // n)
  return [cmap((i * step) % 256) for i in range(n)]


def get_phase2color(list_of_phases, color_map=DEFAULT_COLORMAP_NAME):
  """Map every phase to a color from ``color_map``, in list order.

  Args:
    list_of_phases: Phase identifiers (used as dict keys).
    color_map: Colormap forwarded to ``get_colors_list``.

  Returns:
    A dict ``{phase: rgba_color}``.
  """
  colors = get_colors_list(n=len(list_of_phases), color_map=color_map)
  return dict(zip(list_of_phases, colors))


def plot_labels_preds_timestamps(labels: ArrayLike,
                                 preds: ArrayLike,
                                 timestamps_indices: Optional[IntArrayLike] = None,
                                 phase2color: Optional[Mapping[int, ColorType]] = None):
  """Plot label and prediction bars, optionally marking timestamp seeds.

  Draws on the current pyplot figure and calls ``plt.show()``:
    - predictions at y=1.2;
    - timestamp "seeds" at y=1.1, only when timestamps are given;
    - ground-truth labels at y=1.0.
  Each row is annotated on the left.

  Args:
    labels: Ground-truth label per frame, shape [num_frames].
    preds: Predicted label per frame, shape [num_frames].
    timestamps_indices: Optional frame indices of the annotated timestamps.
      ``None`` or an empty sequence means no seeds are drawn.
    phase2color: Optional mapping from label value to RGBA color. When omitted,
      every value occurring in ``labels`` or ``preds`` gets a color from
      ``get_colors_list``.
  """
  # NOTE: this name is defined twice in this cell; the later definition
  # replaces the earlier one when the cell runs.
  n = len(preds)
  x = np.arange(n)
  ss = [500] * n
  if timestamps_indices is None:
    timestamps_indices = []
  # Accept lists as well as arrays. Test the size rather than `.any()`:
  # `.any()` is False for a lone timestamp at frame 0.
  timestamps_indices = np.asarray(timestamps_indices, dtype=int)
  has_timestamps = timestamps_indices.size > 0
  if not phase2color:
    phases = np.unique(np.concatenate([np.asarray(preds).ravel(),
                                       np.asarray(labels).ravel()]))
    phase2color = dict(zip(phases, get_colors_list(len(phases))))
  if has_timestamps:
    # Color seeds with the same mapping as the bars. A numeric `c` with
    # cmap='jet' would rescale to the seed labels' own min/max range.
    seed_colors = [phase2color[labels[i]] for i in timestamps_indices]
    plt.scatter(timestamps_indices, [1.1] * len(timestamps_indices),
                c=seed_colors, edgecolors="black")
  # `c` holds explicit RGBA colors, so no `cmap` is needed (it would be ignored).
  plt.scatter(x, [1.2] * n, s=ss, marker="|", c=[phase2color[p] for p in preds])
  plt.scatter(x, [1] * n, s=ss, marker="|", c=[phase2color[l] for l in labels])

  if has_timestamps:
    plt.annotate("Seeds", (1, 1.12))
  plt.annotate("GT", (1, 1.02))
  plt.annotate("Preds", (1, 1.22))

  plt.ylim([0.9, 1.3])
  plt.yticks([])
  plt.show()

def plot_labels_preds_timestamps(labels: ArrayLike,
                                 preds: ArrayLike,
                                 timestamps_indices: Optional[IntArrayLike] = None,
                                 phase2color: Optional[Mapping[int, ColorType]] = None):
  """Plot label and prediction bars, optionally marking timestamp seeds.

  Draws on the current pyplot figure and calls ``plt.show()``:
    - predictions at y=1.2;
    - timestamp "seeds" at y=1.1, only when timestamps are given;
    - ground-truth labels at y=1.0.
  Each row is annotated on the left.

  Args:
    labels: Ground-truth label per frame, shape [num_frames].
    preds: Predicted label per frame, shape [num_frames].
    timestamps_indices: Optional frame indices of the annotated timestamps.
      ``None`` or an empty sequence means no seeds are drawn.
    phase2color: Optional mapping from label value to RGBA color. When omitted,
      every value occurring in ``labels`` or ``preds`` gets a color from
      ``get_colors_list``.
  """
  # NOTE: this name is defined twice in this cell; the later definition
  # replaces the earlier one when the cell runs.
  n = len(preds)
  x = np.arange(n)
  ss = [500] * n
  if timestamps_indices is None:
    timestamps_indices = []
  # Accept lists as well as arrays. Test the size rather than `.any()`:
  # `.any()` is False for a lone timestamp at frame 0.
  timestamps_indices = np.asarray(timestamps_indices, dtype=int)
  has_timestamps = timestamps_indices.size > 0
  if not phase2color:
    phases = np.unique(np.concatenate([np.asarray(preds).ravel(),
                                       np.asarray(labels).ravel()]))
    phase2color = dict(zip(phases, get_colors_list(len(phases))))
  if has_timestamps:
    # Color seeds with the same mapping as the bars. A numeric `c` with
    # cmap='jet' would rescale to the seed labels' own min/max range.
    seed_colors = [phase2color[labels[i]] for i in timestamps_indices]
    plt.scatter(timestamps_indices, [1.1] * len(timestamps_indices),
                c=seed_colors, edgecolors="black")
  # `c` holds explicit RGBA colors, so no `cmap` is needed (it would be ignored).
  plt.scatter(x, [1.2] * n, s=ss, marker="|", c=[phase2color[p] for p in preds])
  plt.scatter(x, [1] * n, s=ss, marker="|", c=[phase2color[l] for l in labels])

  if has_timestamps:
    plt.annotate("Seeds", (1, 1.12))
  plt.annotate("GT", (1, 1.02))
  plt.annotate("Preds", (1, 1.22))

  plt.ylim([0.9, 1.3])
  plt.yticks([])
  plt.show()



def plot(mets_dict, title='', is_grid=True, is_legend=True):
    """Draw one line per entry of ``mets_dict`` and show the figure.

    Args:
        mets_dict: Mapping from series name to the values to plot; each name
            is passed through ``str`` and used as its legend label.
        title: Figure title.
        is_grid: Whether to draw grid lines.
        is_legend: Whether to draw a legend.
    """
    for series_name, values in mets_dict.items():
        plt.plot(values, label=str(series_name))
    plt.grid(is_grid)
    if is_legend:
        plt.legend()
    plt.title(title)
    plt.show()
    
_RW_EXPERIMENTS_DIR = '/home/royhirsch/projects/research-il-tempseg/experiments'


def _rw_checkpoints(run_dir, num_splits):
    """Map split ids '1'..str(num_splits) to '<experiments>/<run_dir>/split<k>/best_model.pt'."""
    return {str(k): f'{_RW_EXPERIMENTS_DIR}/{run_dir}/split{k}/best_model.pt'
            for k in range(1, num_splits + 1)}


# Random-walk model checkpoints, keyed by split id.
SALADS_RW_MODELS = _rw_checkpoints('2201_ablation_50salads_scan_weights/conf=0.0750_lap=0.0750', 5)


GTEA_RW_MODELS = _rw_checkpoints('2201_gtea_large_scan/wd=True_neighbors=15.0_beta=30.0_gamma=0.0001_conf=0.1000_lap=0.1500', 4)


BREAKFAST_RW_MODELS = _rw_checkpoints('2001_breakfast_conf=0.1_smooth=0.15', 4)


def get_rw_checkpoint(dataset, split):
    """Return the random-walk model checkpoint path for a dataset split.

    Args:
        dataset: One of '50salads', 'breakfast' or 'gtea'.
        split: Split id. Ints are accepted and converted to the string keys
            used by the tables (e.g. 1 -> '1').

    Raises:
        ValueError: If ``dataset`` is not supported.
        KeyError: If the dataset has no checkpoint for ``split``.
    """
    split = str(split)
    if dataset == '50salads':
        return SALADS_RW_MODELS[split]
    elif dataset == 'breakfast':
        return BREAKFAST_RW_MODELS[split]
    elif dataset == 'gtea':
        return GTEA_RW_MODELS[split]
    raise ValueError(f"Unknown dataset {dataset!r}; expected '50salads', 'breakfast' or 'gtea'.")


_FB_MODELS_DIR = '/home/royhirsch/projects/research-il-tempseg/timestamps/models'


def _fb_checkpoints(dataset_dir, run_dirs):
    """Map split ids '1', '2', ... (in list order) to '<models>/<dataset_dir>/<run>/epoch-50.model'."""
    return {str(split): f'{_FB_MODELS_DIR}/{dataset_dir}/{run}/epoch-50.model'
            for split, run in enumerate(run_dirs, start=1)}


# Baseline ("fb") model checkpoints, keyed by split id.
SALADS_FB_MODEL = _fb_checkpoints('50salads', [
    'margin_map_both2023-01-23_19-39-44_split_1',
    'margin_map_both2023-01-24_11-15-41_split_2',
    'margin_map_both2023-01-24_11-15-51_split_3',
    'margin_map_both2023-01-24_11-16-36_split_4',
    'margin_map_both2023-01-25_08-58-21_split_5',
])


GTEA_FB_MODEL = _fb_checkpoints('gtea', [
    'margin_map_both2023-01-25_14-38-07_split_1',
    'margin_map_both2023-01-26_07-35-25_split_2',
    'margin_map_both2023-01-26_07-35-31_split_3',
    'margin_map_both2023-01-26_07-36-02_split_4',
])


# NOTE(review): loading the split-1 checkpoint below failed in this notebook
# with FileNotFoundError — verify that run directory still exists.
BREAKFAST_FB_MODEL = _fb_checkpoints('breakfast', [
    'margin_map_both2023-01-23_21-28-42_split_1',
    'margin_map_both2023-02-02_15-37-45_split_2_hd=64',
    'margin_map_both2023-02-02_15-38-02_split_3_hd=64',
    'margin_map_both2023-02-02_15-38-13_split_4_hd=64',
])

def get_fb_checkpoint(dataset, split):
    """Return the baseline ("fb") model checkpoint path for a dataset split.

    Args:
        dataset: One of '50salads', 'breakfast' or 'gtea'.
        split: Split id. Ints are accepted and converted to the string keys
            used by the tables (e.g. 1 -> '1').

    Raises:
        ValueError: If ``dataset`` is not supported.
        KeyError: If the dataset has no checkpoint for ``split``.
    """
    split = str(split)
    if dataset == '50salads':
        return SALADS_FB_MODEL[split]
    elif dataset == 'breakfast':
        return BREAKFAST_FB_MODEL[split]
    elif dataset == 'gtea':
        return GTEA_FB_MODEL[split]
    raise ValueError(f"Unknown dataset {dataset!r}; expected '50salads', 'breakfast' or 'gtea'.")

_UVAST_MODELS_DIR = '/home/royhirsch/projects/research-il-tempseg/myuvast'


def _uvast_checkpoints(dataset_prefix, num_splits):
    """Map split ids '1'..str(num_splits) to '<myuvast>/<prefix>_split<k>_stage2.model'."""
    return {str(k): f'{_UVAST_MODELS_DIR}/{dataset_prefix}_split{k}_stage2.model'
            for k in range(1, num_splits + 1)}


# UVAST model checkpoints, keyed by split id.
SALADS_UVAST_MODEL = _uvast_checkpoints('50salads', 5)


GTEA_UVAST_MODEL = _uvast_checkpoints('gtea', 4)


BREAKFAST_UVAST_MODEL = _uvast_checkpoints('breakfast', 4)

def get_uvast_checkpoint(dataset, split):
    """Return the UVAST model checkpoint path for a dataset split.

    Args:
        dataset: One of '50salads', 'breakfast' or 'gtea'.
        split: Split id. Ints are accepted and converted to the string keys
            used by the tables (e.g. 1 -> '1').

    Raises:
        ValueError: If ``dataset`` is not supported.
        KeyError: If the dataset has no checkpoint for ``split``.
    """
    split = str(split)
    if dataset == '50salads':
        return SALADS_UVAST_MODEL[split]
    elif dataset == 'breakfast':
        return BREAKFAST_UVAST_MODEL[split]
    elif dataset == 'gtea':
        return GTEA_UVAST_MODEL[split]
    raise ValueError(f"Unknown dataset {dataset!r}; expected '50salads', 'breakfast' or 'gtea'.")

def get_base_rw_params(dataset):  # TODO
    """Return the default random-walk parameters for ``dataset``.

    All supported datasets share the same settings except ``smooth``, which is
    20 for 'breakfast' and 10 for 'gtea' and '50salads'. A new dict is built on
    every call, so callers may modify the result freely.

    Raises:
        ValueError: If ``dataset`` is not 'gtea', '50salads' or 'breakfast'.
    """
    # NOTE(review): the GTEA random-walk checkpoint directory name encodes
    # gamma=0.0001, while gamma here is 0.001 — confirm which is intended.
    for name, smooth in (('gtea', 10), ('50salads', 10), ('breakfast', 20)):
        if dataset == name:
            return {'similarity_method': 'euclidean', 'sharpening_method': 'exp', 'beta': 30,
                    'average_method': 'min', 'num_neighbors': 15, 'gamma': 0.001, 'smooth': smooth}
    raise ValueError(f"Unknown dataset {dataset!r}; expected 'gtea', '50salads' or 'breakfast'.")


# Plots

In [3]:
# Run configuration for this notebook.
params = dict(
    batch_size=1,
    dataset='breakfast',  # one of '50salads', 'breakfast', 'gtea'
    split='1',            # split id, used as a string key into the checkpoint tables
    gpu_num=0,
    limit=100,            # caps validation batches processed in the inference cell; falsy = no cap
)

# Random-walk hyper-parameters for the chosen dataset.
rw_params = get_base_rw_params(params['dataset'])

# Checkpoint paths of the compared models for the chosen dataset/split.
rw_check_path = get_rw_checkpoint(params['dataset'], params['split'])
fb_check_path = get_fb_checkpoint(params['dataset'], params['split'])
uvast_check_path = get_uvast_checkpoint(params['dataset'], params['split'])


In [4]:
from rws.config import get_and_modify_config
from rws.data import get_dataloaders
from rws.model import MultiStageModel
from evaluation import MetricLoger
from stand_alone.random_walk import predict_random_walk_with_prior, get_sparse_prior_from_timestamps

config = get_and_modify_config(params)
train_dl, val_dl = get_dataloaders(config)
# NOTE(review): if `val_dl` is a torch DataLoader, assigning `shuffle` after
# construction does not replace its sampler, so this is likely a no-op —
# confirm against rws.data.get_dataloaders.
val_dl.shuffle = True

# Random-walk model: the checkpoint holds the whole pickled model object.
# NOTE(review): torch >= 2.6 defaults `torch.load` to `weights_only=True`, which
# rejects full-model pickles; pass `weights_only=False` there (trusted files only).
rw_model = torch.load(rw_check_path, map_location=config.device)
rw_model = rw_model.to(config.device)
rw_model.eval()

# Baseline model: rebuild the architecture, then load its state dict.
fb_model = MultiStageModel(4, 10, 64, 2048, config.num_classes)
fb_model.load_state_dict(torch.load(fb_check_path, map_location=config.device))
fb_model = fb_model.to(config.device)
fb_model.eval()

# One record per validation batch. Only item 0 of each batch is used, which
# matches params['batch_size'] == 1.
res = []
with torch.no_grad():
    for i, batch in enumerate(val_dl):
        video_names, featuers, labels, _, masks, timestamps = batch
        featuers = featuers.to(config.device)
        masks = masks.to(config.device)

        rw_middle_pred, rw_predictions = rw_model(featuers, masks)
        fb_middle_pred, fb_predictions = fb_model(featuers, masks)

        # Per-frame classes: arg-max over dim 1 of the last stage's output.
        _, rw_predicted = torch.max(rw_predictions[-1].data, 1)
        _, fb_predicted = torch.max(fb_predictions[-1].data, 1)
        rw_predicted = rw_predicted.detach().cpu().numpy()
        fb_predicted = fb_predicted.detach().cpu().numpy()

        featuers = featuers.detach().cpu().numpy()
        labels = labels.detach().cpu().numpy()

        # Predictions from the timestamps alone on the raw features (stored as 'rw_i3d').
        rw_i3d_preds = get_sparse_prior_from_timestamps(featuers[0].T, labels[0], timestamps[0], config.num_classes, rw_params)
        # Random walk on the features, using the RW model's softmax output as the prior.
        rw_pp_preds = predict_random_walk_with_prior(frames=featuers[0].T,
                                        prior=torch.nn.functional.softmax(rw_predictions[-1][0], 0).detach().cpu().numpy(),
                                        sharpening_method=rw_params['sharpening_method'],
                                        similarity_method=rw_params['similarity_method'],
                                        beta=rw_params['beta'],
                                        average_method=rw_params['average_method'],
                                        num_neighbors=rw_params['num_neighbors'],
                                        gamma=rw_params['gamma'])
        rw_i3d_mets = MetricLoger()
        rw_i3d_mets.update(labels[0], rw_i3d_preds)

        rw_mets = MetricLoger()
        rw_mets.update(labels[0], rw_predicted[0])

        rw_pp_mets = MetricLoger()
        rw_pp_mets.update(labels[0], rw_pp_preds)

        fb_mets = MetricLoger()
        fb_mets.update(labels[0], fb_predicted[0])

        res.append({'labels': labels[0],
                    'timestamps': timestamps[0],
                    'features': featuers[0].T,
                    'rw_i3d_preds': np.asarray(rw_i3d_preds),
                    'rw_i3d_mets': rw_i3d_mets.calc(),

                    'rw_preds': rw_predicted[0],
                    'rw_mets': rw_mets.calc(),
                    'rw_logits': rw_predictions[-1][0].T.detach().cpu().numpy(),

                    'rw_pp_preds': np.asarray(rw_pp_preds),
                    'rw_pp_mets': rw_pp_mets.calc(),

                    'fb_preds': fb_predicted[0],
                    'fb_mets': fb_mets.calc(),
                    'fb_logits': fb_predictions[-1][0].T.detach().cpu().numpy(),
                    })
        # Stop after exactly `limit` records. The previous `i > limit` check
        # only broke after index limit + 1, i.e. limit + 2 records.
        if params['limit'] and len(res) >= params['limit']:
            break

FileNotFoundError: [Errno 2] No such file or directory: '/home/royhirsch/projects/research-il-tempseg/timestamps/models/breakfast/margin_map_both2023-01-23_21-28-42_split_1/epoch-50.model'

In [6]:
import copy
import scipy

def plot_probs(label, logits, tmp=1, linewidth=3):
  """Plot per-class probabilities, prediction/GT bars and an entropy strip.

  Only classes that occur in ``label`` or in the arg-max prediction are drawn.
  The figure has, from top to bottom:
    - one row per drawn class, showing its softmax probability over time;
    - a row with the predicted classes (y=0.5) above the ground truth (y=0.2);
    - a strip colored by the per-frame entropy of the drawn-class
      distribution, scaled to [0, 1] and raised to the third power.

  Args:
    label: Ground-truth class per frame, shape [num_frames].
    logits: Class scores, shape [num_classes, num_frames].
    tmp: Factor applied to the logits before the softmax (an inverse
      temperature: values > 1 sharpen the distribution).
    linewidth: Line width of the bar and entropy markers.
  """
  # Explicit submodule imports: a bare `import scipy` does not load
  # `scipy.special` / `scipy.stats` on every SciPy version.
  from scipy import special, stats

  label = np.asarray(label)
  probs = special.softmax(np.asarray(logits) * tmp, 0)
  # Keep only the classes present in the ground truth or in the prediction.
  shown = np.unique(np.concatenate([np.unique(label), np.unique(probs.argmax(0))]))
  probs = probs[shown, :]
  num_shown = len(shown)
  class2color = dict(zip(shown, get_colors_list(num_shown)))

  # One row per drawn class, then the bars row and the entropy row. The row
  # count used to be hard-coded for exactly four classes. Figure options used
  # to go to a separate, empty `plt.figure`, which also received `tight_layout`.
  fig, axs = plt.subplots(num_shown + 2, 1,
                          gridspec_kw={'height_ratios': [1] * num_shown + [1.5, 2]},
                          figsize=(12, 4), facecolor='white', dpi=300)

  frames = np.arange(probs.shape[1])
  for ax, cls, prob in zip(axs, shown, probs):
    color = class2color[cls]
    ax.set_facecolor('white')
    ax.plot(prob, color=color, linewidth=2)
    ax.fill_between(frames, prob, where=prob > 0, color=color, alpha=0.5)
    ax.axis('off')
    ax.set_xticks([])
    ax.set_yticks([])

  # Bars: predicted class on top, ground truth below.
  marker_size = 250
  ax = axs[num_shown]
  pred_classes = shown[probs.argmax(0)]
  bar_x = np.arange(len(label))
  ax.scatter(bar_x, [0.5] * len(label), marker='|', s=marker_size,
             c=[class2color[c] for c in pred_classes], linewidth=linewidth)
  ax.scatter(bar_x, [0.2] * len(label), marker='|', s=marker_size,
             c=[class2color[c] for c in label], linewidth=linewidth)
  ax.set_facecolor('white')
  ax.set_ylim([0.04, 0.63])
  ax.set_yticks([])
  ax.set_xticks([])
  ax.axis('off')

  # Entropy strip (scipy renormalizes each frame's drawn-class probabilities).
  ax = axs[num_shown + 1]
  ent = stats.entropy(probs)
  ent_max = np.max(ent)
  if ent_max > 0:  # avoid 0/0 when every frame is fully confident
    ent = ent / ent_max
  ent = ent ** 3
  # The huge marker size makes each frame's bar fill the strip vertically.
  ax.scatter(frames, [1] * len(frames), s=99600, marker="|", c=ent, cmap='jet', linewidth=linewidth)
  ax.set_ylim([0., 1.])
  ax.axis('off')

  fig.tight_layout(h_pad=0)

linewidth = 3  # bar marker width for plot_probs

# Example to inspect: record `n` of `res`, frames [s, e).
n = 95
s = 20
e = 359

# Alternative example:
# n = 37
# s = 400
# e = 1400

# Logits are stored transposed (frames first); `.T` restores the
# [num_classes, num_frames] layout that plot_probs expects.
logits = res[n]['rw_logits'].T[:, s:e]
label = res[n]['labels'][s:e]
# plot_probs has no `is_rw` parameter; passing it raised a TypeError.
plot_probs(label, logits)

logits = res[n]['fb_logits'].T[:, s:e]
label = res[n]['labels'][s:e]
plot_probs(label, logits)

'/home/royhirsch/projects/research-il-tempseg'