# Q2

## 1. CTC loss implemented from scratch

In [11]:


import math
import torch
import torch.nn.functional as F


def ctc_loss(log_probs : torch.Tensor, targets : torch.Tensor, input_lengths : torch.Tensor, target_lengths : torch.Tensor, blank : int = 0, reduction : str = 'mean', finfo_min_fp32: float = torch.finfo(torch.float32).min, finfo_min_fp16: float = torch.finfo(torch.float16).min, alignment : bool = False):
	"""CTC loss computed with an explicit forward (alpha) recursion over the extended label sequence.

	Intended to match ``torch.nn.functional.ctc_loss`` for padded 2-D targets.

	Args:
		log_probs: (T, N, C) log-probabilities.
		targets: (N, S) padded label indices. Entries past ``target_lengths`` do not affect the
			loss, but they must still be valid class indices because they are gathered.
		input_lengths: (N,) number of valid frames per sample.
		target_lengths: (N,) number of valid labels per sample.
		blank: index of the blank class.
		reduction: 'none' returns per-sample losses. 'sum' returns their sum. 'mean' divides each
			loss by its target length (clamped to at least 1) and then averages over the batch.
		finfo_min_fp32 / finfo_min_fp16: finite stand-in for log(0) in the respective dtype.
		alignment: unused; kept for backward compatibility.

	Returns:
		Tensor of shape (N,) for 'none', otherwise a scalar tensor.

	Raises:
		ValueError: if ``reduction`` is not 'none', 'mean' or 'sum'.
	"""
	if reduction not in ('none', 'mean', 'sum'):
		raise ValueError(f'{reduction} is not a valid value for reduction')

	input_time_size, batch_size = log_probs.shape[:2]
	B = torch.arange(batch_size, device = input_lengths.device)

	# Extended sequence: blank, l_1, blank, l_2, ..., blank, l_S, blank, <copy of l_1>.
	# Its length is 2(S+1). For any target length L <= S, index 2L-1 is the last label
	# and index 2L is the trailing blank.
	_targets = torch.cat([targets, targets[:, :1]], dim = -1)
	_targets = torch.stack([torch.full_like(_targets, blank), _targets], dim = -1).flatten(start_dim = -2)

	# diff_labels[:, s] is True when state s may be entered directly from state s-2.
	# This holds when extended labels s and s-2 differ, which is never the case for blanks.
	diff_labels = torch.cat([torch.as_tensor([[False, False]], device = targets.device).expand(batch_size, -1), _targets[:, 2:] != _targets[:, :-2]], dim = 1)

	# A large finite negative number stands in for log(0) instead of -inf.
	zero_padding, zero = 2, torch.tensor(finfo_min_fp16 if log_probs.dtype == torch.float16 else finfo_min_fp32, device = log_probs.device, dtype = log_probs.dtype)
	log_probs_ = log_probs.gather(-1, _targets.expand(input_time_size, -1, -1))
	# Two leading padding states let the s-1 and s-2 transitions be written as plain slices.
	log_alpha = torch.full((input_time_size, batch_size, zero_padding + _targets.shape[-1]), zero, device = log_probs.device, dtype = log_probs.dtype)
	log_alpha[0, :, zero_padding + 0] = log_probs[0, :, blank]
	log_alpha[0, :, zero_padding + 1] = log_probs[0, B, _targets[:, 1]]
	for t in range(1, input_time_size):
		# Predecessors of each state: stay (s), advance (s-1), skip (s-2, only where allowed).
		prev = torch.stack([log_alpha[t - 1, :, 2:], log_alpha[t - 1, :, 1:-1], torch.where(diff_labels, log_alpha[t - 1, :, :-2], zero)])
		log_alpha[t, :, 2:] = log_probs_[t] + torch.logsumexp(prev, dim = 0)

	# Valid paths end on the last label (state 2L-1) or on the trailing blank (state 2L),
	# read at each sample's last valid frame.
	l1l2 = log_alpha[input_lengths - 1, B].gather(-1, torch.stack([zero_padding + target_lengths * 2 - 1, zero_padding + target_lengths * 2], dim = -1))
	loss = -torch.logsumexp(l1l2, dim = -1)
	if reduction == 'mean':
		# Same normalisation as PyTorch: divide by target length (at least 1), then average.
		return (loss / target_lengths.clamp(min = 1)).mean()
	if reduction == 'sum':
		return loss.sum()
	return loss

	

def ctc_alignment(log_probs : torch.Tensor, targets : torch.Tensor, input_lengths : torch.Tensor, target_lengths : torch.Tensor, blank: int = 0, finfo_min_fp32: float = torch.finfo(torch.float32).min, finfo_min_fp16: float = torch.finfo(torch.float16).min):
	"""Assign each target label to an input frame by backtracking through the CTC lattice.

	The forward (alpha) recursion uses logsumexp over the three predecessors (stay / previous
	state / skip). For every frame and state, it also records which predecessor had the largest
	alpha. Backtracking then starts from the better of the two admissible final states at each
	sample's last valid frame.

	Args:
		log_probs: (T, N, C) log-probabilities.
		targets: (N, S) padded label indices.
		input_lengths: (N,) number of valid frames per sample (presumably >= 1).
		target_lengths: (N,) number of valid labels per sample.
		blank: index of the blank class.
		finfo_min_fp32 / finfo_min_fp16: finite stand-in for log(0) in the respective dtype.

	Returns:
		(N, S) int64 tensor. Entry [b, i] is a frame at which the backtracked path of sample b
		occupies label i, or 0 if the path never visits it (e.g. labels past target_lengths[b]).
		When the path stays on a label for several frames, ``scatter_`` with repeated indices
		does not guarantee which of those frames is kept.
	"""
	input_time_size, batch_size = log_probs.shape[:2]
	B = torch.arange(batch_size, device = input_lengths.device)

	# Extended sequence: blank, l_1, blank, l_2, ..., blank, l_S, blank (length 2S + 1).
	_targets = torch.cat([
		torch.stack([torch.full_like(targets, blank), targets], dim = -1).flatten(start_dim = -2),
		torch.full_like(targets[:, :1], blank)
	], dim = -1)
	# diff_labels[:, s]: state s may be entered directly from s - 2 (extended labels differ).
	diff_labels = torch.cat([
		torch.as_tensor([[False, False]], device = targets.device).expand(batch_size, -1),
		_targets[:, 2:] != _targets[:, :-2]
	], dim = 1)

	zero_padding, zero = 2, torch.tensor(finfo_min_fp16 if log_probs.dtype == torch.float16 else finfo_min_fp32, device = log_probs.device, dtype = log_probs.dtype)
	padded_t = zero_padding + _targets.shape[-1]
	# Only the current frame's alphas are kept. Two leading padding states turn the
	# s-1 / s-2 transitions into plain slices.
	log_alpha = torch.full((batch_size, padded_t), zero, device = log_probs.device, dtype = log_probs.dtype)
	log_alpha[:, zero_padding + 0] = log_probs[0, :, blank]
	log_alpha[:, zero_padding + 1] = log_probs[0, B, _targets[:, 1]]

	# A backpointer is 0 (stay), 1 (from s - 1) or 2 (from s - 2), so it fits in 2 bits.
	# Four of them are packed into each uint8 of the (T, N, packed_width) history.
	packmask = 0b11
	packnibbles = 4
	packed_width = int(math.ceil(padded_t / packnibbles))
	backpointers = torch.zeros((input_time_size, batch_size, packed_width), device = log_probs.device, dtype = torch.uint8)
	# Unpacked scratch row. It must hold packed_width * packnibbles entries so that it can be
	# viewed as (N, packed_width, packnibbles) when packing and unpacking.
	backpointer = torch.zeros((batch_size, packed_width * packnibbles), device = log_probs.device, dtype = torch.uint8)
	packshift = torch.tensor([[[6, 4, 2, 0]]], device = log_probs.device, dtype = torch.uint8)
	# Frames at or beyond a sample's input length must not advance its alphas. After the loop,
	# log_alpha then holds each sample's values at its own last valid frame.
	lengths = input_lengths.to(log_probs.device).unsqueeze(-1)

	for t in range(1, input_time_size):
		prev = torch.stack([log_alpha[:, 2:], log_alpha[:, 1:-1], torch.where(diff_labels, log_alpha[:, :-2], zero)])
		new_alpha = log_probs[t].gather(-1, _targets) + prev.logsumexp(dim = 0)
		log_alpha[:, zero_padding:] = torch.where(lengths > t, new_alpha, log_alpha[:, zero_padding:])
		backpointer[:, zero_padding:(zero_padding + prev.shape[-1])] = prev.argmax(dim = 0)
		backpointers[t] = (backpointer.view(batch_size, -1, packnibbles) << packshift).sum(dim = -1, dtype = torch.uint8)

	# Best of the two admissible final states: last label (2L-1) or trailing blank (2L).
	l1l2 = log_alpha.gather(-1, torch.stack([zero_padding + target_lengths * 2 - 1, zero_padding + target_lengths * 2], dim = -1))

	path = torch.zeros(input_time_size, batch_size, device = log_alpha.device, dtype = torch.long)
	path[input_lengths - 1, B] = zero_padding + target_lengths * 2 - 1 + l1l2.argmax(dim = -1)

	# Backtrack. Frames past a sample's input length stay at state index 0. That padding
	# state's backpointer is never written (always 0), so those frames add 0 to the path.
	for t in range(input_time_size - 1, 0, -1):
		indices = path[t]
		backpointer = (backpointers[t].unsqueeze(-1) >> packshift).view_as(backpointer)
		path[t - 1] += indices - backpointer.gather(-1, indices.unsqueeze(-1)).squeeze(-1).bitwise_and_(packmask)

	# Write frame indices into the visited extended states, then keep only the label states (odd positions).
	return torch.zeros_like(_targets, dtype = torch.int64).scatter_(-1, (path.t() - zero_padding).clamp(min = 0), torch.arange(input_time_size, device = log_alpha.device).expand(batch_size, -1))[:, 1::2]

def ctc_alignment_targets(log_probs, targets, input_lengths, target_lengths, blank = 0, ctc_loss = F.ctc_loss, retain_graph = True):
	"""Per-frame soft targets derived from the gradient of a CTC loss.

	The function:
	1. Computes the summed loss and takes its gradient g with respect to ``log_probs``.
	2. Maps g through the softmax Jacobian: g - p * sum_c(g), with p = exp(log_probs).
	3. Returns ``p * valid`` minus that value, where ``valid`` masks out frames at or beyond each
	   sample's input length.

	Presumably the result is the per-frame class posterior of the CTC lattice on valid frames.
	TODO: confirm this for the chosen ``ctc_loss`` implementation.

	Args:
		log_probs: (T, N, C) log-probabilities; must be part of an autograd graph.
		targets, input_lengths, target_lengths, blank: forwarded to ``ctc_loss``.
		ctc_loss: callable with the ``F.ctc_loss`` signature. It is called with
			``reduction='sum'`` and must return a scalar.
		retain_graph: forwarded to ``torch.autograd.grad``.

	Returns:
		Detached tensor shaped like ``log_probs``.
	"""
	total_loss = ctc_loss(log_probs, targets, input_lengths, target_lengths, blank = blank, reduction = 'sum')
	class_probs = log_probs.exp()
	(loss_grad,) = torch.autograd.grad(total_loss, log_probs, retain_graph = retain_graph)
	# Chain rule through log_softmax: d/dlogits = g - p * sum_c(g).
	logits_grad = loss_grad - class_probs * loss_grad.sum(dim = -1, keepdim = True)
	frame_index = torch.arange(len(log_probs), device = input_lengths.device, dtype = input_lengths.dtype)
	valid_frames = (frame_index[:, None] < input_lengths[None, :])[..., None]
	return (class_probs * valid_frames - logits_grad).detach()

def logadd(x0, x1, x2):
	"""Elementwise, numerically stable log(exp(x0) + exp(x1) + exp(x2)) for same-shaped tensors."""
	stacked = torch.stack((x0, x1, x2), dim = 0)
	return stacked.logsumexp(dim = 0)


class LogsumexpFunction(torch.autograd.function.Function):
	"""Autograd function for the elementwise log-sum-exp of three tensors.

	The forward pass returns log(exp(x0) + exp(x1) + exp(x2)), shifting by the elementwise max
	for stability. The backward pass reuses the shifted exponentials:
	d/dx_i = exp(x_i - m) / sum_j exp(x_j - m).
	"""

	@staticmethod
	def forward(ctx, x0, x1, x2):
		m = torch.max(torch.max(x0, x1), x2)
		# Shifting by an infinite max would produce inf - inf = nan; shift by 0 there instead.
		m = m.masked_fill_(torch.isinf(m), 0)
		e0 = (x0 - m).exp_()
		e1 = (x1 - m).exp_()
		e2 = (x2 - m).exp_()
		# The clamp keeps the log finite and the backward division safe when every term underflows.
		e = (e0 + e1).add_(e2).clamp_(min = 1e-16)
		ctx.save_for_backward(e0, e1, e2, e)
		# Out-of-place log: `e` is saved for backward and must keep holding the sum of exponentials.
		return e.log().add_(m)

	@staticmethod
	def backward(ctx, grad_output):
		e0, e1, e2, e = ctx.saved_tensors
		g = grad_output / e
		return g * e0, g * e1, g * e2


## 2. Check against `torch.nn.functional.ctc_loss`

In [12]:
# Fixed seed so that Restart & Run All regenerates the same random inputs (and hence the same losses).
SEED = 0
torch.manual_seed(SEED)

NUM_FRAMES, BATCH_SIZE, NUM_CLASSES, MAX_TARGET_LEN = 50, 16, 20, 30  # T, N, C (incl. blank), S
log_probs = torch.randn(NUM_FRAMES, BATCH_SIZE, NUM_CLASSES).log_softmax(2)
targets = torch.randint(1, NUM_CLASSES, (BATCH_SIZE, MAX_TARGET_LEN), dtype=torch.long)  # never the blank (0)
input_lengths = torch.full((BATCH_SIZE,), NUM_FRAMES, dtype=torch.long)
target_lengths = torch.randint(10, MAX_TARGET_LEN, (BATCH_SIZE,), dtype=torch.long)
blank = 0

In [13]:
# Reference value from PyTorch's built-in CTC implementation.
loss = F.ctc_loss(
	log_probs=log_probs,
	targets=targets,
	input_lengths=input_lengths,
	target_lengths=target_lengths,
	blank=0,
	reduction='mean',
)
print(loss)

tensor(6.0781)


In [14]:
custom_ctc = ctc_loss(log_probs, targets, input_lengths, target_lengths, blank = 0, reduction = 'mean')
print(custom_ctc)
# The hand-written recursion must agree with the F.ctc_loss reference (`loss`) computed above.
# The tolerance absorbs float32 rounding differences between the two implementations.
assert torch.allclose(custom_ctc, loss, rtol = 1e-4, atol = 1e-4), (custom_ctc, loss)

tensor(6.0781)
