In [None]:
#default_exp loss 

In [1]:
import os
# Project root (defaults to the original Google Drive / Colab location).
# Set the SSL_FASTAI2_ROOT environment variable to run from another checkout.
PROJECT_DIR = os.environ.get('SSL_FASTAI2_ROOT', '/content/drive/My Drive/Semi Supervised Learning/Self-Supervised-Learning-fastai2')
os.chdir(PROJECT_DIR)

In [None]:
# Run the project's setup script (the captured output below shows it installing fastai2).
!sh initialise.sh

Collecting fastai2
[?25l  Downloading https://files.pythonhosted.org/packages/cc/50/2f37212be57b7ee3e9c947336f75a66724468b21a3ca68734eaa82e7ebf3/fastai2-0.0.30-py3-none-any.whl (179kB)
[K     |█▉                              | 10kB 13.8MB/s eta 0:00:01[K     |███▋                            | 20kB 2.9MB/s eta 0:00:01[K     |█████▌                          | 30kB 3.7MB/s eta 0:00:01[K     |███████▎                        | 40kB 4.0MB/s eta 0:00:01[K     |█████████▏                      | 51kB 3.2MB/s eta 0:00:01[K     |███████████                     | 61kB 3.7MB/s eta 0:00:01[K     |████████████▉                   | 71kB 4.1MB/s eta 0:00:01[K     |██████████████▋                 | 81kB 4.1MB/s eta 0:00:01[K     |████████████████▍               | 92kB 4.5MB/s eta 0:00:01[K     |██████████████████▎             | 102kB 4.5MB/s eta 0:00:01[K     |████████████████████            | 112kB 4.5MB/s eta 0:00:01[K     |██████████████████████          | 122kB 4.5MB/s eta 0

In [None]:
#export 
from ssl_fastai2.imports import *

In [None]:
#export 
class DotProduct(nn.Module):
  '''Pairwise dot-product similarity between the rows of two 2-D tensors.

  Given `x` of shape (n, d) and `y` of shape (m, d), returns an (n, m) tensor
  whose [i, j] entry is the dot product of `x[i]` and `y[j]`.
  '''
  def forward(self, x, y):
    # rows: (n, 1, d); cols: (1, d, m). Contracting the last two axes of `rows`
    # against the first two of `cols` sums over the singleton axis and d.
    rows = x[:, None]
    cols = y.T[None]
    return torch.tensordot(rows, cols, dims = 2)
  
class NTXnetLoss(nn.Module): # Normalised Temperature Scaled Cross Entropy
  '''NT-Xent (normalised temperature-scaled cross entropy) contrastive loss.

  `forward(zi, zj)` treats row k of `zi` and row k of `zj` as a positive pair,
  and every other row of the concatenated batch as a negative for both.
  Inputs are assumed to be 2-D (batch, features) tensors with the same number
  of rows -- TODO confirm for callers passing other ranks.

  Args:
    batch_size: expected number of rows in each of `zi` / `zj`. It is used to
      precompute the negatives mask. Batches of a different size (e.g. the last
      batch of an epoch) get a mask built on the fly.
    similarity_type: 'cosine' (case-insensitive) selects cosine similarity;
      any other value falls back to `DotProduct`.
    temp: temperature; the similarity matrix is divided by it.
    use_softmax: apply a softmax to the logits before `CrossEntropyLoss`.
      NOTE: `CrossEntropyLoss` already applies log-softmax internally, so this
      option effectively applies softmax twice. It is kept for backward
      compatibility.
  '''
  def __init__(self, batch_size, similarity_type = 'cosine', temp = 0.1, use_softmax = False):
    super().__init__()
    self.use_softmax = use_softmax
    self.temp = temp

    # Default device for the cached mask; `forward` moves it to the inputs' device.
    self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    self.batch_size = batch_size
    self.neg_mask = self._get_neg_mask().to(self.device)
    self.similarity = self._get_similarity(similarity_type.lower())
    # Summed here, then divided by the number of anchors in `forward` (i.e. a mean).
    self.criterion = nn.CrossEntropyLoss(reduction = 'sum')
    if use_softmax:
      self.softmax = nn.Softmax(dim = -1)

  def _get_similarity(self, sim_type):
    '''Return the pairwise similarity callable: cosine for 'cosine', otherwise a dot product.'''
    if sim_type == 'cosine':
      return self.cosine_sim
    else:
      return DotProduct()

  def cosine_sim(self, x, y):
    '''Pairwise cosine similarity: entry [i, j] compares x[i] with y[j] along the last dim.'''
    return nn.CosineSimilarity(dim = -1)(x.unsqueeze(1), y.unsqueeze(0))

  @property
  def _T_batch(self): return 2*self.batch_size

  def _get_neg_mask(self, batch_size = None):
    '''Boolean (2B, 2B) mask that is True exactly at the negative pairs.

    It is False on the main diagonal (self-similarity) and on the diagonals
    offset by +/-B (the positive pairs). `batch_size` (B) defaults to
    `self.batch_size`.
    '''
    bs = self.batch_size if batch_size is None else batch_size
    eye = torch.eye(2*bs, dtype = torch.bool)
    # Rolling the identity right by `bs` columns marks (i, i+bs) and (i+bs, i).
    return ~(eye | eye.roll(bs, dims = 1))

  def forward(self, zi, zj):
    '''Return the mean NT-Xent loss over all 2*B anchors, where B = zi.shape[0].

    Raises:
      ValueError: if `zi` and `zj` have a different number of rows.
    '''
    bs = zi.shape[0]
    if zj.shape[0] != bs:
      raise ValueError(f'zi and zj must have the same number of rows, got {bs} and {zj.shape[0]}')
    n = 2*bs

    feature_rep = torch.cat([zi, zj], dim = 0)
    device = feature_rep.device

    similarity_matrix = self.similarity(feature_rep, feature_rep)
    similarity_matrix = similarity_matrix/self.temp

    # Reuse the cached mask for full batches; rebuild it for a differently sized (e.g. final) batch.
    neg_mask = self.neg_mask if bs == self.batch_size else self._get_neg_mask(bs)
    # Explicit column count (not -1) so bs == 1, which has zero negatives, still reshapes.
    negatives = similarity_matrix[neg_mask.to(device)].view(n, n - 2)

    # Row k's positive is at column k+bs for k < bs and at column k-bs for k >= bs.
    # Concatenate in row order so each positive lines up with its row of `negatives`.
    r_pos = torch.diag(similarity_matrix, bs)
    l_pos = torch.diag(similarity_matrix, -bs)
    positives = torch.cat([r_pos, l_pos]).view(n, 1)

    # The positive is always column 0 of the logits, hence all-zero targets.
    labels = torch.zeros(n, dtype = torch.long, device = device)
    logits = torch.cat([positives, negatives], dim = 1)
    if self.use_softmax:
      logits = self.softmax(logits)
    loss = self.criterion(logits, labels)

    return loss/n

In [None]:
#export  
class BaseSSLLoss(nn.Module):
  '''Shared configuration for self-supervised losses built from a global and a branch term.

  Args:
    model: network stored as `self.model` (registered as a submodule when it
      is an `nn.Module`).
    with_negatives: flag stored as-is; subclasses decide how it is used.
    global_loss_func: callable stored as `self.global_loss_func`.
    branch_loss_func: callable stored as `self.branch_loss_func`. When it is
      falsy (e.g. the default None), the attribute is NOT created at all,
      presumably so a subclass can provide its own -- confirm with subclasses.
    glob_loss_weight, branch_loss_weight: scalar weights stored for subclasses.
  '''
  def __init__(self, model, with_negatives = False, global_loss_func = None, branch_loss_func = None,  
               glob_loss_weight = 1., branch_loss_weight = 1.):
    super().__init__()
    self.model = model
    self.with_negatives = with_negatives
    self.glob_loss_weight = glob_loss_weight
    self.branch_loss_weight = branch_loss_weight
    self.global_loss_func = global_loss_func
    # Leave the attribute unset when no (truthy) branch loss was supplied.
    if not branch_loss_func:
      return
    self.branch_loss_func = branch_loss_func

In [None]:
#export 
class CGDLoss(BaseSSLLoss):
  '''Weighted sum of a global loss and a branch (embedding) loss.

  `forward` receives, as `*yb`:
    * with_negatives=True:  (pred, targ, negatives, labels)
    * with_negatives=False: (pred, targ, labels)

  `pred[0]` is scored by `global_loss_func` against `labels`. `pred[1]` is
  scored by `branch_loss_func` against `self.model(targ, glob = False)` (and,
  with negatives, `self.model(negatives, glob = False)`). Presumably `pred[0]`
  and `pred[1]` are the model's global and branch outputs -- TODO confirm.
  '''
  def forward(self, *yb):
    '''Depending on the application, `targ` may be positives or augmentations of the same image.

    Returns:
      glob_loss_weight * global_loss_func(pred[0], labels)
      + branch_loss_weight * branch_loss_func(pred[1], emb(targ)[, emb(negatives)])
      where emb(t) = self.model(t, glob = False).
    '''
    if self.with_negatives:
      pred, targ, negatives, labels = yb 
    else:
      pred, targ, labels = yb

    # Embed the extra inputs with the same model, using its non-global output.
    targ = self.model(targ, glob = False)
    if self.with_negatives:
      neg = self.model(negatives, glob = False)

    global_loss = self.glob_loss_weight * self.global_loss_func(pred[0], labels)

    # Both paths must define the branch term; without negatives the branch loss
    # is a two-argument comparison of pred[1] with the embedded targets
    # (e.g. a contrastive loss taking (zi, zj)).
    if self.with_negatives:
      branch_loss = self.branch_loss_func(pred[1], targ, neg)
    else:
      branch_loss = self.branch_loss_func(pred[1], targ)
    target_loss = self.branch_loss_weight * branch_loss

    return global_loss + target_loss