<a href="https://colab.research.google.com/github/ShreyashDhoot/TiPAI-TSPO/blob/main/Auditor.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import matplotlib.pyplot as plt
from datasets import load_dataset
import PIL
from PIL import Image
import torchvision
from torchvision import transforms
import requests
import timm
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F

In [1]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"

NameError: name 'torch' is not defined

In [None]:
# Open the "train" split of the Hub dataset in streaming mode, so samples are
# iterated lazily rather than materialised up front.
dataset = load_dataset(
  "ShreyashDhoot/KTO_trial",
  streaming=True,
  split="train",
)

In [None]:
### CLIP backbone used as the feature extractor
model_name = "resnet50_clip.openai"

# features_only=True makes the timm model return a list of intermediate
# feature maps; out_indices=[4] keeps only the map at feature index 4.
model = timm.create_model(
  model_name,
  out_indices=[4],
  features_only=True,
  pretrained=True,
)

# Both calls act on the module itself and return it, so they can be chained:
# switch to inference behaviour, then move the weights to the target device.
model.eval().to(device)

# Preprocessing settings that belong to the pretrained weights ...
data_config = timm.data.resolve_model_data_config(model)
# ... and the matching evaluation-time (non-augmenting) transform.
processor = timm.data.create_transform(is_training=False, **data_config)

In [None]:
## 1. helper that converts one raw dataset batch into tensors
def get_batch_images(batch,device):
  """Build the image and label tensors for one batch.

  Every entry of ``batch['image']`` is converted to RGB and passed through the
  module-level ``processor`` transform; ``batch['label']`` becomes a float
  column vector (one row per sample).

  Returns:
    ``(images, labels)``, both placed on ``device``.
  """
  processed = []
  for picture in batch['image']:
    processed.append(processor(picture.convert('RGB')))

  label_values = torch.tensor(batch['label']).float()
  label_column = label_values.to(device).unsqueeze(1)

  image_batch = torch.stack(processed).to(device)
  return image_batch, label_column

In [None]:
# 1. defining the model
class BCE_pair_path(nn.Module):
  """Small head that scores a feature map through a one-channel risk map.

  A 1x1 convolution collapses the input channels into a single-channel
  "risk map", a global max-pool keeps the highest value of that map for each
  sample, and a two-layer MLP turns that single value into the output logits.

  Args:
    input_classes: number of channels in the incoming feature map.
    num_classes: number of logits produced per sample. Defaults to 1, which
      matches a single binary logit.
  """

  def __init__(self,input_classes=2048,num_classes=1):
    super().__init__()
    # 1x1 conv: one risk value per spatial position. Forward hooks elsewhere
    # attach to this attribute by name, so the name must stay `risk_conv`.
    self.risk_conv=nn.Conv2d(input_classes,1,kernel_size=1)
    # Global max over the spatial dimensions -> one value per sample.
    self.pool=nn.AdaptiveMaxPool2d((1,1))
    # The last layer honours `num_classes`; previously the argument was
    # accepted but ignored and the output width was always 1.
    self.mlp=nn.Sequential(nn.Linear(1,16),
                           nn.ReLU(),
                           nn.Linear(16,num_classes))

  def forward(self,x):
    """Return logits of shape ``(batch, num_classes)`` for a 4-D feature map ``x``."""
    risk_map=self.risk_conv(x)
    peak_risk=self.pool(risk_map)
    # Drop the 1x1 spatial dims so the MLP sees (batch, 1).
    flattened=torch.flatten(peak_risk,1)
    return self.mlp(flattened)

In [None]:
# 2. trainable head, its BCE criterion and its optimizer
model_3 = BCE_pair_path()
model_3.to(device)  # moves the module's parameters in place

# Consumes raw logits; the sigmoid is applied inside the loss.
loss_fn_bce = nn.BCEWithLogitsLoss()

# Only the head's parameters are handed to the optimizer.
optimizer = torch.optim.Adam(model_3.parameters(), lr=1e-3)

# Latest output captured by each registered forward hook, keyed by hook name.
activations = dict()


def get_activations_hook(name):
  """Return a forward hook that stores the module's output in ``activations[name]``."""
  def _store_output(module, inputs, output):
    # Overwritten on every forward pass, so only the most recent output is kept.
    activations[name] = output
  return _store_output


# Capture the risk map produced by the head's 1x1 convolution. Keep the handle
# so the hook can be removed later.
risk_map_hook = get_activations_hook('risk_map')
handle = model_3.risk_conv.register_forward_hook(risk_map_hook)

In [None]:
# 3. defining pair wise loss
def loss_fn_pairwise(logits,labels):
  """Pairwise logistic ranking loss between positive and negative logits.

  L_pair = mean( log(1 + exp(-(S+i - S-j))) )

  taken over every (positive i, negative j) pair in the batch, where S+ are
  the logits whose label is 1 and S- the logits whose label is 0.

  Args:
    logits: model outputs; must have the same shape as ``labels`` so the
      boolean masks below can index it.
    labels: binary labels (1 = positive, 0 = negative).

  Returns:
    Scalar loss tensor. When the batch lacks either class, no pair can be
    formed and a zero scalar (with ``requires_grad=True`` so that a later
    ``backward()`` still works) is returned.
  """
  s_plus=logits[labels==1]   # logits whose true label is 1 (flattened to 1-D)
  s_minus=logits[labels==0]  # logits whose true label is 0 (flattened to 1-D)
  if s_plus.numel()==0 or s_minus.numel()==0:
    return torch.tensor(0.0,device=logits.device,requires_grad=True)

  # (P, 1) - (1, N) broadcasts to the full P x N matrix of margins.
  diff_matrix=s_plus.unsqueeze(1)-s_minus.unsqueeze(0)

  # softplus(-d) == log(1 + exp(-d)) but stays finite for large |d|; the naive
  # form overflows exp() in float32 and yields inf.
  return F.softplus(-diff_matrix).mean()

In [None]:
# 4. lets define patch wise loss
def loss_fn_patch(risk_map,lables,k=11):
  """Pairwise logistic ranking loss over each image's top-k risk-map values.

  For every image the ``k`` largest risk-map values are selected. Every
  selected value from a label-1 image is then compared with every selected
  value from a label-0 image using ``log(1 + exp(-(s_plus - s_minus)))``, and
  the mean over all such pairs is returned.

  Args:
    risk_map: tensor whose first dimension is the batch; all remaining
      dimensions are flattened into per-image patch scores.
    lables: per-image binary labels (1 / 0) with one entry per image. The
      misspelled parameter name is kept so existing keyword callers keep
      working.
    k: number of highest-scoring patches used per image. Clamped to the
      number of patches available.

  Returns:
    Scalar loss tensor. When the batch lacks either class a zero scalar (with
    ``requires_grad=True``) is returned.
  """
  # Bind the argument locally. The body used to read a *global* `labels`
  # because the parameter is spelled `lables`.
  labels=lables.reshape(-1)

  batch_size=risk_map.shape[0]
  flat_map=risk_map.reshape(batch_size,-1)
  # topk raises if k exceeds the number of patches, so clamp it.
  k=min(k,flat_map.shape[1])
  topk_values,_=torch.topk(flat_map,k,dim=1)

  s_patch_plus=topk_values[labels==1]
  s_patch_minus=topk_values[labels==0]

  if s_patch_plus.numel()==0 or s_patch_minus.numel()==0:
    return torch.tensor(0.0, device=risk_map.device,requires_grad=True)

  # All positive patch scores against all negative patch scores.
  diffs=s_patch_plus.reshape(-1,1)-s_patch_minus.reshape(1,-1)

  # softplus(-d) == log(1 + exp(-d)) without the float32 overflow of exp().
  return F.softplus(-diffs).mean()

In [None]:
# 5. plot heatmap
def plot_heatmap(image_tensor,risk_map_tensor):
  """Show an image next to the same image overlaid with its risk heatmap.

  Args:
    image_tensor: one image laid out channels-first (C, H, W); it is permuted
      to (H, W, C) for matplotlib.
    risk_map_tensor: raw (pre-sigmoid) risk map for that image as a 3-D
      tensor (channels, h, w). Only channel 0 is displayed.

  The risk map is squashed with a sigmoid and bilinearly upsampled to the
  image's own height and width before being drawn with 50% opacity.
  """
  # Back on the CPU, channels last, as a NumPy array for imshow.
  img=image_tensor.detach().cpu().permute(1,2,0).numpy()
  # Min-max rescale to [0, 1] for display. This is not an exact inverse of the
  # preprocessing normalisation. A constant image would be 0/0, so it is only
  # shifted instead.
  value_range=img.max()-img.min()
  img=img-img.min()
  if value_range>0:
    img=img/value_range

  heatmap=torch.sigmoid(risk_map_tensor)
  # Upsample to the image's own spatial size rather than a hard-coded 224x224,
  # so the overlay lines up for any input resolution.
  target_size=tuple(image_tensor.shape[-2:])
  heatmap=F.interpolate(heatmap.unsqueeze(0),size=target_size,
                        mode='bilinear',align_corners=False)[0,0]
  heatmap=heatmap.detach().cpu().numpy()

  fig, (ax1,ax2)=plt.subplots(1,2,figsize=(10,5))
  ax1.imshow(img)
  ax1.set_title("original image")
  ax1.axis('off')

  # Draw the image first, then the semi-transparent heatmap on top of it.
  ax2.imshow(img)
  im=ax2.imshow(heatmap,cmap="jet",alpha=0.5)
  ax2.set_title("Risk Heatmap")
  ax2.axis('off')

  plt.colorbar(im,ax=ax2)
  plt.show()

In [None]:
# Training loop
# Trains only the head (`model_3`). The CLIP backbone (`model`) is run under
# no_grad and its parameters are not in the optimizer.
epochs=1
# Weights of the three loss terms in the total objective.
lambda_bce=1
lambda_pair=1
lambda_patch=1

for epoch in range(epochs):
  # Re-create the batched view each epoch so iteration starts from the
  # beginning of the dataset again.
  batched_dataset=dataset.batch(batch_size=32)
  epoch_loss=0.0
  num_batches=0
  for batch_number,batch in enumerate(tqdm(batched_dataset)):
    images,labels=get_batch_images(batch,device)
    # Backbone features; features_only returns a list and [0] is the single
    # requested feature map.
    with torch.no_grad():
      clip_logits = model(images)[0]

    optimizer.zero_grad()
    train_logits=model_3(clip_logits)
    # Filled by the forward hook on model_3.risk_conv during the call above,
    # so it must be read after the forward pass.
    current_heatmaps=activations['risk_map']
    # BCE loss
    l_bce=loss_fn_bce(train_logits,labels)
    #pairwise loss
    l_pair=loss_fn_pairwise(train_logits,labels)
    #patch wise loss
    l_patch=loss_fn_patch(current_heatmaps,labels,k=11)

    total_loss=(lambda_bce * l_bce)+(lambda_pair * l_pair)+(lambda_patch * l_patch)

    total_loss.backward()
    optimizer.step()

    epoch_loss += total_loss.item()
    num_batches += 1

    if batch_number % 10 == 0:
      with torch.no_grad():
        # "Gap" = mean positive logit minus mean negative logit; a growing gap
        # means the two classes are being separated.
        pos_avg = train_logits[labels==1].mean().item() if (labels==1).any() else 0
        neg_avg = train_logits[labels==0].mean().item() if (labels==0).any() else 0
        gap = pos_avg - neg_avg
        print(f"Batch {batch_number} | BCE: {l_bce:.3f} | Pair: {l_pair:.3f} | Patch: {l_patch:.3f} | Gap: {gap:.3f}")
        print("="*80)

    if batch_number % 50 == 0:
      # Plot the first image in the batch and its corresponding heatmap
      plot_heatmap(images[0], current_heatmaps[0])
      print(f"Image label = {labels[0]}")

  # Report the accumulated loss; it was previously summed but never shown.
  if num_batches:
    print(f"Epoch {epoch + 1}/{epochs} | mean loss: {epoch_loss / num_batches:.3f}")
